Commit c2092862
Authored April 03, 2022
Author: Brandon T. Willard
Committer: Brandon T. Willard, May 09, 2022
Rename Function.fn to Function.vm
Parent: d39b852a
Showing 19 changed files with 151 additions and 194 deletions
aesara/compile/function/types.py  +81 -126
aesara/compile/profiling.py  +2 -2
aesara/graph/basic.py  +2 -2
aesara/misc/check_blas.py  +1 -1
aesara/printing.py  +1 -1
aesara/sandbox/rng_mrg.py  +2 -2
aesara/scan/op.py  +18 -18
aesara/scan/opt.py  +2 -0
doc/faq.rst  +2 -2
doc/tutorial/debug_faq.rst  +2 -2
doc/tutorial/profiling.rst  +2 -2
doc/tutorial/profiling_example_out.prof  +1 -1
tests/compile/function/test_types.py  +7 -7
tests/link/test_numba.py  +4 -4
tests/link/test_numba_performance.py  +2 -2
tests/link/test_vm.py  +16 -16
tests/sandbox/test_rng_mrg.py  +1 -1
tests/scan/test_basic.py  +4 -4
tests/tensor/nnet/test_conv.py  +1 -1
aesara/compile/function/types.py
@@ -9,7 +9,7 @@ import logging
 import time
 import warnings
 from itertools import chain
-from typing import List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type
 
 import numpy as np
@@ -34,6 +34,10 @@ from aesara.link.basic import Container
 from aesara.link.utils import raise_with_op
 
+if TYPE_CHECKING:
+    from aesara.link.vm import VM
+
 _logger = logging.getLogger("aesara.compile.function.types")
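The hunk above imports `VM` only under `TYPE_CHECKING`, which pairs with the quoted annotation `vm: "VM"` that appears later in `Function.__init__`. A minimal sketch of that pattern follows, assuming a standard-library stand-in (`fractions.Fraction`) in place of `aesara.link.vm.VM`; `Wrapper` and `annotation` are hypothetical names used only for illustration.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; never imported at runtime.
    from fractions import Fraction


class Wrapper:
    """Hypothetical class that annotates an argument with a deferred type."""

    def __init__(self, value: "Fraction") -> None:
        # The quoted annotation is stored as a plain string and is not
        # resolved here, so the missing runtime import raises no `NameError`.
        self.value = value


# The annotation is kept as the string that was written in the signature.
annotation = Wrapper.__init__.__annotations__["value"]
```

The guard is commonly used to avoid circular imports or import-time cost when a name is needed only for annotations.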
@@ -271,42 +275,45 @@ DUPLICATE = object()
 class Function:
-    """
-    Type of the functions returned by aesara.function or
-    aesara.FunctionMaker.create.
+    r"""A class that wraps the execution of a `VM` making it easier for use as a "function".
 
     `Function` is the callable object that does computation. It has the storage
     of inputs and outputs, performs the packing and unpacking of inputs and
     return values. It implements the square-bracket indexing so that you can
     look up the value of a symbolic node.
 
-    Functions are copyable via {{{fn.copy()}}} and {{{copy.copy(fn)}}}.
+    Functions are copyable via `Function.copy` and the `copy.copy` interface.
     When a function is copied, this instance is duplicated. Contrast with
     self.maker (instance of `FunctionMaker`) that is shared between copies.
     The meaning of copying a function is that the containers and their current
     values will all be duplicated. This requires that mutable inputs be
     copied, whereas immutable inputs may be shared between copies.
 
-    A Function instance is hashable, on the basis of its memory
-    address (its id).
+    A Function instance is hashable, on the basis of its memory address (its
+    id).
 
     A Function instance is only equal to itself.
 
     A Function instance may be serialized using the `pickle` or
     `cPickle` modules. This will save all default inputs, the graph,
     and WRITEME to the pickle file.
 
-    A Function instance have a ``trust_input`` field that default to
-    False. When True, we don't do extra check of the input to give
-    better error message. In some case, python code will still return
-    the good results if you pass a python or numpy scalar instead of a
-    numpy tensor. C code should raise an error if you pass an object
-    of the wrong type.
+    A `Function` instance has a `Function.trust_input` field that defaults to
+    ``False``. When ``True``, the `Function` will skip all checks on the
+    inputs.
 
     Attributes
     ----------
     finder
+        Dictionary mapping several kinds of things to containers.
+        We set an entry in finder for:
+        - the index of the input
+        - the variable instance the input is based on
+        - the name of the input
+        All entries map to the container or to DUPLICATE if an ambiguity
+        is detected.
     inv_finder
+        Reverse lookup of `finder`. It maps containers to `SymbolicInput`\s.
 
     """
@@ -321,111 +328,59 @@ class Function:
     If the value is 'raise', then an AliasedMemoryError will be raised
     if aliased storage is detected during pickle.dump.
 
-    """
-    input_storage = None
-    """
-    List of Container instances.
-    """
-    output_storage = None
-    """
-    List of Container instances.
-    """
-    indices = None
-    """
-    List of (SymbolicInput, indices, [SymbolicInput,...]),
-    one tuple for each input.
-    The first tuple element is the SymbolicInput object for the corresponding
-    function input.
-    The second and third tuple elements are used only by Kits, which
-    are deprecated.
-    """
-    defaults = None
-    """
-    List of 3-tuples, one 3-tuple for each input.
-    Tuple element 0: Bool: Is this input required at each function call?
-    Tuple element 1: Bool: Should this inputs value be reverted after
-    each call?
-    Tuple element 2: Any: The value associated with this input.
-    """
-    unpack_single = None
-    """
-    Bool: for outputs lists of length 1, should the 0'th element be
-    returned directly?
-    """
-    return_none = None
-    """
-    Bool: whether the function should return None or not.
-    """
-    maker = None
-    """
-    FunctionMaker instance.
-    """
-    fn = None
-    """
-    A function that evaluates the graph. Typically a linker's make_thunk method
-    created this function.
-    """
-    finder = None
-    """
-    Dictionary mapping several kinds of things to containers.
-    We set an entry in finder for:
-    - the index of the input
-    - the variable instance the input is based on
-    - the name of the input
-    All entries map to the container or to DUPLICATE if an ambiguity
-    is detected.
-    """
-    inv_finder = None
-    """
-    Dict. Reverse lookup of `finder`.
-    It maps container -> SymbolicInput
     """
 
     def __init__(
         self,
-        fn,
+        vm: "VM",
         input_storage,
         output_storage,
         indices,
         outputs,
         defaults,
-        unpack_single,
-        return_none,
+        unpack_single: bool,
+        return_none: bool,
         output_keys,
-        maker,
-        name=None,
+        maker: "FunctionMaker",
+        name: Optional[str] = None,
     ):
-        self.fn = fn
+        """
+        Parameters
+        ----------
+        vm
+            A `VM` instance that evaluates the graph when called.
+        input_storage
+            List of storage cells for each input.
+        output_storage
+            List of storage cells for each output.
+        indices
+            List of ``(SymbolicInput, indices, [SymbolicInput,...])``, one
+            tuple for each input. The first tuple element is the `SymbolicInput`
+            object for the corresponding function input. The second and third
+            tuple elements are used only by Kits, which are deprecated.
+        outputs
+            TODO
+        defaults
+            List of 3-tuples, one 3-tuple for each input.
+            Tuple element 0: ``bool``. Is this input required at each function
+            call?
+            Tuple element 1: ``bool``. Should this inputs value be reverted
+            after each call?
+            Tuple element 2: ``Any``. The value associated with this input.
+        unpack_single
+            For outputs lists of length 1, should the 0'th element be
+            returned directly?
+        return_none
+            Whether the function should return ``None`` or not.
+        output_keys
+            TODO
+        maker
+            The `FunctionMaker` that created this instance.
+        name
+            A string name.
+        """
+        # TODO: Rename to `vm`
+        self.vm = vm
         self.input_storage = input_storage
         self.output_storage = output_storage
         self.indices = indices
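The `finder` attribute described in the `Function` docstring maps an input's index, variable, and name to its container, and falls back to the `DUPLICATE` sentinel when a key is ambiguous. The sketch below illustrates that idea only. `build_finder` and its `(variable, name, container)` tuples are hypothetical simplifications, not the code in `types.py`.

```python
DUPLICATE = object()  # sentinel marking an ambiguous lookup key


def build_finder(inputs):
    """Map the index, variable, and name of each input to its container.

    `inputs` is a hypothetical list of ``(variable, name, container)`` tuples.
    """
    finder = {}
    for index, (variable, name, container) in enumerate(inputs):
        finder[index] = container
        finder[variable] = container
        if name is None:
            continue
        # A name shared by two inputs is ambiguous, so flag it instead of
        # silently pointing at whichever input came last.
        finder[name] = DUPLICATE if name in finder else container
    return finder


x_var, y_var, z_var = object(), object(), object()
finder = build_finder(
    [
        (x_var, "x", ["x-cell"]),
        (y_var, "dup", ["y-cell"]),
        (z_var, "dup", ["z-cell"]),
    ]
)
```

Looking up `"dup"` yields the sentinel, while the index and variable keys for the same inputs still resolve to their own containers.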
@@ -441,7 +396,7 @@ class Function:
         self.output_keys = output_keys
         # See if we have any mutable / borrow inputs
-        # TODO: this only need to be set if there is more then 1 input
+        # TODO: this only need to be set if there is more than one input
         self._check_for_aliased_inputs = False
         for i in maker.inputs:
             # If the input is a shared variable, the memory region is
@@ -575,7 +530,7 @@ class Function:
         # TODO: Get rid of all this `expanded_inputs` nonsense
         assert len(self.maker.expanded_inputs) == len(self.input_storage)
-        # This is used only when `fn.need_update_inputs` is `False`, because
+        # This is used only when `vm.need_update_inputs` is `False`, because
         # we're using one of the VM objects and it is putting updates back into
         # the input containers all by itself.
         self.n_returned_outputs = len(self.output_storage) - sum(
@@ -752,7 +707,7 @@ class Function:
         # Construct new storage_map that map new variable to old storage,
         # so that the ensuing function shares storage with the original one
-        storage_map = self.fn.storage_map
+        storage_map = self.vm.storage_map
         new_storage_map = {}
         # TODO: We could share the output storage, but we must make sure
         # 2 different function call won't override each other values. This
@@ -1015,24 +970,24 @@ class Function:
         t0_fn = time.time()
         try:
             outputs = (
-                self.fn()
+                self.vm()
                 if output_subset is None
-                else self.fn(output_subset=output_subset)
+                else self.vm(output_subset=output_subset)
             )
         except Exception:
             restore_defaults()
-            if hasattr(self.fn, "position_of_error"):
+            if hasattr(self.vm, "position_of_error"):
                 # this is a new vm-provided function or c linker
                 # they need this because the exception manipulation
                 # done by raise_with_op is not implemented in C.
                 thunk = None
-                if hasattr(self.fn, "thunks"):
-                    thunk = self.fn.thunks[self.fn.position_of_error]
+                if hasattr(self.vm, "thunks"):
+                    thunk = self.vm.thunks[self.vm.position_of_error]
                 raise_with_op(
                     self.maker.fgraph,
-                    node=self.fn.nodes[self.fn.position_of_error],
+                    node=self.vm.nodes[self.vm.position_of_error],
                     thunk=thunk,
-                    storage_map=getattr(self.fn, "storage_map", None),
+                    storage_map=getattr(self.vm, "storage_map", None),
                 )
             else:
                 # old-style linkers raise their own exceptions
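The `except` branch above uses `position_of_error` to index into the VM's `nodes` and `thunks` so that the failing node can be reported. A minimal sketch of that lookup follows, assuming a hypothetical `FakeVM` that records the failing position itself. The real code passes these values to `raise_with_op` rather than returning them, and `failing_node_and_thunk` is an illustrative helper, not an Aesara function.

```python
class FakeVM:
    """Hypothetical VM that records the index of the thunk that failed."""

    def __init__(self, nodes, thunks):
        self.nodes = nodes
        self.thunks = thunks

    def __call__(self):
        for position, thunk in enumerate(self.thunks):
            try:
                thunk()
            except Exception:
                # Remember where execution stopped, then re-raise.
                self.position_of_error = position
                raise


def failing_node_and_thunk(vm):
    """Return the (node, thunk) pair at the recorded error position, if any."""
    if not hasattr(vm, "position_of_error"):
        return None
    thunk = None
    if hasattr(vm, "thunks"):
        thunk = vm.thunks[vm.position_of_error]
    return vm.nodes[vm.position_of_error], thunk


def bad_thunk():
    raise ValueError("shape mismatch")


vm = FakeVM(nodes=["add", "dot"], thunks=[lambda: None, bad_thunk])
try:
    vm()
except ValueError:
    located = failing_node_and_thunk(vm)
```

Because the lookup relies on `hasattr`, objects that do not expose `position_of_error` fall through to the `else` branch, as in the hunk.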
@@ -1056,7 +1011,7 @@ class Function:
         # if we are allowing garbage collection, remove the
         # output reference from the internal storage cells
-        if getattr(self.fn, "allow_gc", False):
+        if getattr(self.vm, "allow_gc", False):
             assert len(self.output_storage) == len(self.maker.fgraph.outputs)
             for o_container, o_variable in zip(
                 self.output_storage, self.maker.fgraph.outputs
@@ -1068,7 +1023,7 @@ class Function:
         # TODO: Get rid of this and `expanded_inputs`, since all the VMs now
         # perform the updates themselves
-        if getattr(self.fn, "need_update_inputs", True):
+        if getattr(self.vm, "need_update_inputs", True):
             # Update the inputs that have an update function
             for input, storage in reversed(
                 list(zip(self.maker.expanded_inputs, self.input_storage))
@@ -1092,8 +1047,8 @@ class Function:
         if profile:
             profile.fct_callcount += 1
             profile.fct_call_time += dt_call
-            if hasattr(self.fn, "update_profile"):
-                self.fn.update_profile(profile)
+            if hasattr(self.vm, "update_profile"):
+                self.vm.update_profile(profile)
             if profile.ignore_first_call:
                 profile.reset()
                 profile.ignore_first_call = False
@@ -1137,10 +1092,10 @@ class Function:
         """
         # 1.no allow_gc return False
         # 2.has allow_gc, if allow_gc is False, return True
-        if not getattr(self.fn, "allow_gc", True):
-            for key in self.fn.storage_map:
+        if not getattr(self.vm, "allow_gc", True):
+            for key in self.vm.storage_map:
                 if not isinstance(key, Constant):
-                    self.fn.storage_map[key][0] = None
+                    self.vm.storage_map[key][0] = None
 
             for node in self.nodes_with_inner_function:
                 if hasattr(node.fn, "free"):
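The last hunk of `types.py` clears every non-constant storage cell when the VM was created with `allow_gc` set to `False`. That behavior can be sketched as follows, assuming hypothetical stand-ins: `Constant` and `FakeVM` here are local illustration classes, not the Aesara types, and `free_storage` mirrors only the first loop of the method.

```python
class Constant:
    """Hypothetical marker for graph constants whose storage must be kept."""


class FakeVM:
    """Hypothetical VM exposing only `allow_gc` and `storage_map`."""

    def __init__(self, storage_map, allow_gc):
        self.storage_map = storage_map
        self.allow_gc = allow_gc


def free_storage(vm):
    """Clear every non-constant storage cell when the VM keeps its buffers."""
    # An object without an `allow_gc` attribute is treated as `True`,
    # so nothing is cleared for it.
    if not getattr(vm, "allow_gc", True):
        for key in vm.storage_map:
            if not isinstance(key, Constant):
                # Each cell is a one-element list; drop the reference in place.
                vm.storage_map[key][0] = None


const = Constant()
vm = FakeVM({"x": [1.0], "out": [2.0], const: [3.0]}, allow_gc=False)
free_storage(vm)
```

Only the cell contents are replaced, so any container holding a reference to the same one-element list observes the cleared value.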
aesara/compile/profiling.py
@@ -217,7 +217,7 @@ class ProfileStats:
     #
     vm_call_time = 0.0
-    # Total time spent in Function.fn.__call__
+    # Total time spent in Function.vm.__call__
     #
     apply_time = None
@@ -781,7 +781,7 @@ class ProfileStats:
             )
         if self.fct_call_time > 0:
             print(
-                f"  Time in Function.fn.__call__: {self.vm_call_time}s ({100 * self.vm_call_time / self.fct_call_time:.3f}%)",
+                f"  Time in Function.vm.__call__: {self.vm_call_time}s ({100 * self.vm_call_time / self.fct_call_time:.3f}%)",
                 file=file,
             )
             local_time = sum(self.apply_time.values())
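The profiling line above reports the VM call time together with its share of the total `Function.__call__` time. A small sketch of the same arithmetic and formatting follows. `vm_time_line` is a hypothetical helper, and the sample values are taken from the profiling example output elsewhere in this commit.

```python
def vm_time_line(vm_call_time: float, fct_call_time: float) -> str:
    """Format the share of the total call time that was spent in the VM call."""
    # Callers are expected to check `fct_call_time > 0` first, as the
    # surrounding `if` in the hunk does, to avoid dividing by zero.
    share = 100 * vm_call_time / fct_call_time
    return f"  Time in Function.vm.__call__: {vm_call_time}s ({share:.3f}%)"


line = vm_time_line(1.192093e-05, 5.698204e-05)
print(line)
```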
aesara/graph/basic.py
@@ -1139,9 +1139,9 @@ def clone_replace(
     Parameters
     ----------
-    output : Aesara Variables (or Aesara expressions)
+    output
         Aesara expression that represents the computational graph.
-    replace : dict
+    replace
         Dictionary describing which subgraphs should be replaced by what.
     rebuild_kwds
         Keywords to `rebuild_collect_shared`.
aesara/misc/check_blas.py
@@ -59,7 +59,7 @@ def execute(execute=True, verbose=True, M=2000, N=2000, K=2000, iters=10, order=
     if any(x.op.__class__.__name__ == "Gemm" for x in f.maker.fgraph.toposort()):
         c_impl = [
             hasattr(thunk, "cthunk")
-            for node, thunk in zip(f.fn.nodes, f.fn.thunks)
+            for node, thunk in zip(f.vm.nodes, f.vm.thunks)
             if node.op.__class__.__name__ == "Gemm"
         ]
         assert len(c_impl) == 1
aesara/printing.py
@@ -222,7 +222,7 @@ def debugprint(
             results_to_print.extend(obj.maker.fgraph.outputs)
             profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs])
             if print_storage:
-                smap.extend([obj.fn.storage_map for item in obj.maker.fgraph.outputs])
+                smap.extend([obj.vm.storage_map for item in obj.maker.fgraph.outputs])
             else:
                 smap.extend([None for item in obj.maker.fgraph.outputs])
             topo = obj.maker.fgraph.toposort()
aesara/sandbox/rng_mrg.py
@@ -75,7 +75,7 @@ def multMatVect(v, A, m1, B, m2):
     f.input_storage[3].storage[0] = B
     f.input_storage[4].storage[0] = v[3:]
     f.input_storage[5].storage[0] = m2
-    f.fn()
+    f.vm()
     r = f.output_storage[0].storage[0]
     return r
@@ -829,7 +829,7 @@ class MRG_RandomStream:
             v = rval[i - 1]
             f.input_storage[1].storage[0] = v[:3]
             f.input_storage[4].storage[0] = v[3:]
-            f.fn()
+            f.vm()
             rval[i] = f.output_storage[0].storage[0]
         if inc_rstate:
aesara/scan/op.py
@@ -1594,8 +1594,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
             from aesara.scan.utils import InnerFunctionError
 
             # TODO: Extract `Capsule` object and use that
-            # c_thunk = getattr(self.fn.fn.thunks[0], "cthunk", None)
-            # if len(self.fn.fn.thunks) == 1 and c_thunk:
+            # c_thunk = getattr(self.fn.vm.thunks[0], "cthunk", None)
+            # if len(self.fn.vm.thunks) == 1 and c_thunk:
             # thunk_capsule = c_thunk.cthunk
             # # We need to perform the following after calling
             # # the thunk function:
@@ -1633,20 +1633,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
                         outputs,
                         outer_output_dtypes,
                         outer_output_ndims,
-                        self.fn.fn,
+                        self.fn.vm,
                     )
                 except InnerFunctionError as exc:
                     exc_type = type(exc.args[0])
                     exc_value = exc.args[0]
                     exc_trace = exc.args[1]
 
-                    if hasattr(self.fn.fn, "position_of_error") and hasattr(
-                        self.fn.fn, "thunks"
+                    if hasattr(self.fn.vm, "position_of_error") and hasattr(
+                        self.fn.vm, "thunks"
                     ):
                         raise_with_op(
                             self.fn.maker.fgraph,
-                            self.fn.fn.nodes[self.fn.fn.position_of_error],
-                            self.fn.fn.thunks[self.fn.fn.position_of_error],
+                            self.fn.vm.nodes[self.fn.vm.position_of_error],
+                            self.fn.vm.thunks[self.fn.vm.position_of_error],
                             exc_info=(exc_type, exc_value, exc_trace),
                         )
                     else:
@@ -1661,8 +1661,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
                     profile.callcount += 1
                     profile.nbsteps += n_steps
                     profile.call_time += t_call
-                    if hasattr(self.fn.fn, "update_profile"):
-                        self.fn.fn.update_profile(profile)
+                    if hasattr(self.fn.vm, "update_profile"):
+                        self.fn.vm.update_profile(profile)
 
         except (ImportError, MissingGXX):
             p = self.perform
@@ -1795,7 +1795,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
         inner_output_storage = self.fn.output_storage
         old_inner_output_storage = [None] * len(inner_output_storage)
         old_inner_output_data = [None] * len(inner_output_storage)
-        fn = self.fn.fn
+        vm = self.fn.vm
         offset = (
             info.n_seqs
             + sum(
@@ -1938,18 +1938,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
             t0_fn = time.time()
             try:
-                fn()
+                vm()
             except Exception:
-                if hasattr(fn, "position_of_error"):
+                if hasattr(vm, "position_of_error"):
                     # this is a new vm-provided function or c linker
                     # they need this because the exception manipulation
                     # done by raise_with_op is not implemented in C.
-                    if hasattr(fn, "thunks"):
+                    if hasattr(vm, "thunks"):
                         # For the CVM
                         raise_with_op(
                             self.fn.maker.fgraph,
-                            fn.nodes[fn.position_of_error],
-                            fn.thunks[fn.position_of_error],
+                            vm.nodes[vm.position_of_error],
+                            vm.thunks[vm.position_of_error],
                         )
                     else:
                         # For the c linker
@@ -1957,7 +1957,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
                         # temps values So for now, we just don't print
                         # the extra shapes/strides info
                         raise_with_op(
-                            self.fn.maker.fgraph, fn.nodes[fn.position_of_error]
+                            self.fn.maker.fgraph, vm.nodes[vm.position_of_error]
                         )
                 else:
                     # old-style linkers raise their own exceptions
@@ -2200,8 +2200,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
             profile.nbsteps += n_steps
             profile.call_time += t_call
             profile.vm_call_time += t_fn
-            if hasattr(self.fn.fn, "update_profile"):
-                self.fn.fn.update_profile(profile)
+            if hasattr(self.fn.vm, "update_profile"):
+                self.fn.vm.update_profile(profile)
 
         self.t_call = t_call
         self.t_fn = t_fn
aesara/scan/opt.py
@@ -751,6 +751,8 @@ def add_nitsot_outputs(
     new_outputs_inner,
 ) -> Tuple[Apply, Dict[Variable, Variable]]:
 
+    assert isinstance(old_scan_node.op, Scan)
+
     nb_new_outs = len(new_outputs_inner)
 
     # Create the initial values for the new nitsot outputs
doc/faq.rst
@@ -141,8 +141,8 @@ with
 Also, for small Aesara functions, you can remove more Python overhead by
 making an Aesara function that does not take any input. You can use shared
-variables to achieve this. Then you can call it like this: ``f.fn()`` or
-``f.fn(n_calls=N)`` to speed it up. In the last case, only the last
+variables to achieve this. Then you can call it like this: ``f.vm()`` or
+``f.vm(n_calls=N)`` to speed it up. In the last case, only the last
 function output (out of N calls) is returned.
 
 You can also use the ``C`` linker that will put all nodes in the same C
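The FAQ change above shows that the rename reaches user-facing code: calls written as ``f.fn()`` must become ``f.vm()``. This commit renames the attribute outright. One way a project might soften such a rename is a deprecated read-only alias. The sketch below is not part of this commit, and `RenamedHolder` is a hypothetical stand-in for `Function`.

```python
import warnings


class RenamedHolder:
    """Hypothetical stand-in for a class whose `fn` attribute became `vm`."""

    def __init__(self, vm):
        self.vm = vm

    @property
    def fn(self):
        # Keep the old name readable, but steer callers to the new one.
        warnings.warn(
            "`fn` is deprecated; use `vm` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.vm


holder = RenamedHolder(vm=lambda: "evaluated")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = holder.fn()
```

With `stacklevel=2`, the warning is attributed to the caller's line, which makes remaining uses of the old name easier to find.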
doc/tutorial/debug_faq.rst
@@ -140,9 +140,9 @@ Running the above code generates the following error message:
   File "test1.py", line 31, in <module>
     f(np.random.random((5, 10)))
   File "PATH_TO_AESARA/aesara/compile/function/types.py", line 605, in __call__
-    self.fn.thunks[self.fn.position_of_error])
+    self.vm.thunks[self.vm.position_of_error])
   File "PATH_TO_AESARA/aesara/compile/function/types.py", line 595, in __call__
-    outputs = self.fn()
+    outputs = self.vm()
 ValueError: Shape mismatch: x has 10 cols (and 5 rows) but y has 20 rows (and 10 cols)
 Apply node that caused the error: Dot22(x, DimShuffle{1,0}.0)
 Inputs types: [TensorType(float64, (None, None)), TensorType(float64, (None, None))]
doc/tutorial/profiling.rst
@@ -52,8 +52,8 @@ function. aesara.function() has an optional parameter ``name`` that
 defaults to None. Change it to something else to help you profile many
 Aesara functions. In that section, we also see the number of times the
 function was called (1) and the total time spent in all those
-calls. The time spent in Function.fn.__call__ and in thunks is useful
-to understand Aesara overhead.
+calls. The time spent in :meth:`Function.vm.__call__` and in thunks is useful
+to understand Aesara's overhead.
 
 Also, we see the time spent in the two parts of the compilation
 process: optimization (modify the graph to make it more stable/faster)
doc/tutorial/profiling_example_out.prof
@@ -2,7 +2,7 @@ Function profiling
 ==================
   Message: None
   Time in 1 calls to Function.__call__: 5.698204e-05s
-  Time in Function.fn.__call__: 1.192093e-05s (20.921%)
+  Time in Function.vm.__call__: 1.192093e-05s (20.921%)
   Time in thunks: 6.198883e-06s (10.879%)
   Total compile time: 3.642474e+00s
     Aesara Optimizer time: 7.326508e-02s
tests/compile/function/test_types.py
@@ -346,8 +346,8 @@ class TestFunction:
         cpy = ori.copy(share_memory=True)
 
         # Test if memories shared
-        storage_map_ori = ori.fn.storage_map
-        storage_map_cpy = cpy.fn.storage_map
+        storage_map_ori = ori.vm.storage_map
+        storage_map_cpy = cpy.vm.storage_map
         fgraph_cpy = cpy.maker.fgraph
 
         # Assert intermediate and Constants storages are shared.
@@ -424,11 +424,11 @@ class TestFunction:
...
@@ -424,11 +424,11 @@ class TestFunction:
# 2. SharedVariable is updatable -> values did update(z == 5)
# 2. SharedVariable is updatable -> values did update(z == 5)
# 1. sharedvariable is swap -> Rpl sharedvariables share storage
# 1. sharedvariable is swap -> Rpl sharedvariables share storage
names
=
map_SV
.
keys
()
names
=
map_SV
.
keys
()
for
key
in
cpy
.
fn
.
storage_map
:
for
key
in
cpy
.
vm
.
storage_map
:
if
key
.
name
in
names
:
if
key
.
name
in
names
:
assert
(
assert
(
map_SV
[
key
.
name
]
.
container
.
storage
[
0
]
map_SV
[
key
.
name
]
.
container
.
storage
[
0
]
==
cpy
.
fn
.
storage_map
[
key
][
0
]
==
cpy
.
vm
.
storage_map
[
key
][
0
]
)
)
second_time
=
True
second_time
=
True
...
@@ -688,18 +688,18 @@ class TestFunction:
...
@@ -688,18 +688,18 @@ class TestFunction:
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
func
=
function
([
x
],
x
+
1
)
func
=
function
([
x
],
x
+
1
)
func
.
fn
.
allow_gc
=
False
func
.
vm
.
allow_gc
=
False
func
([
1
])
func
([
1
])
check_list
=
[]
check_list
=
[]
for
key
,
val
in
func
.
fn
.
storage_map
.
items
():
for
key
,
val
in
func
.
vm
.
storage_map
.
items
():
if
not
isinstance
(
key
,
Constant
):
if
not
isinstance
(
key
,
Constant
):
check_list
.
append
(
val
)
check_list
.
append
(
val
)
assert
any
(
val
[
0
]
for
val
in
check_list
)
assert
any
(
val
[
0
]
for
val
in
check_list
)
func
.
free
()
func
.
free
()
for
key
,
val
in
func
.
fn
.
storage_map
.
items
():
for
key
,
val
in
func
.
vm
.
storage_map
.
items
():
if
not
isinstance
(
key
,
Constant
):
if
not
isinstance
(
key
,
Constant
):
assert
val
[
0
]
is
None
assert
val
[
0
]
is
None
...
...
tests/link/test_numba.py
@@ -3505,7 +3505,7 @@ def test_config_options_parallel():
     with config.change_flags(numba__vectorize_target="parallel"):
         aesara_numba_fn = function([x], x * 2, mode=numba_mode)
-        numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"]
+        numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
         assert numba_mul_fn.targetoptions["parallel"] is True
...
@@ -3514,7 +3514,7 @@ def test_config_options_fastmath():
     with config.change_flags(numba__fastmath=True):
         aesara_numba_fn = function([x], x * 2, mode=numba_mode)
-        numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"]
+        numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
         assert numba_mul_fn.targetoptions["fastmath"] is True
...
@@ -3523,12 +3523,12 @@ def test_config_options_cached():
     with config.change_flags(numba__cache=True):
         aesara_numba_fn = function([x], x * 2, mode=numba_mode)
-        numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"]
+        numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
         assert not isinstance(
             numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache
         )
     with config.change_flags(numba__cache=False):
         aesara_numba_fn = function([x], x * 2, mode=numba_mode)
-        numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"]
+        numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
         assert isinstance(numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache)
tests/link/test_numba_performance.py
View file @ c2092862
...
@@ -52,11 +52,11 @@ def test_careduce_performance(careduce_fn, numpy_fn, axis, inputs, input_vals):
     assert np.array_equal(numba_res, numpy_res)
-    # FYI: To test the Numba JITed function directly, use `aesara_numba_fn.fn.jit_fn`
+    # FYI: To test the Numba JITed function directly, use `aesara_numba_fn.vm.jit_fn`
     numpy_timer = timeit.Timer("numpy_fn(*input_vals)", "pass", globals=locals())
     numba_timer = timeit.Timer(
-        "aesara_numba_fn.fn.jit_fn(*input_vals)", "pass", globals=locals()
+        "aesara_numba_fn.vm.jit_fn(*input_vals)", "pass", globals=locals()
     )
     # c_timer = timeit.Timer("aesara_c_fn(*input_vals)", "pass", globals=locals())
...
tests/link/test_vm.py
View file @ c2092862
...
@@ -86,7 +86,7 @@ def test_use_c_thunks():
         ),
     )
     assert np.array_equal(a * b, f(a, b))
-    assert any(hasattr(t, "cthunk") for t in f.fn.thunks) == use_c_thunks
+    assert any(hasattr(t, "cthunk") for t in f.vm.thunks) == use_c_thunks
 @pytest.mark.skipif(
...
@@ -215,9 +215,9 @@ def test_partial_function(linker):
     if linker == "cvm":
         from aesara.link.c.cvm import CVM
-        assert isinstance(f.fn, CVM)
+        assert isinstance(f.vm, CVM)
     else:
-        assert isinstance(f.fn, Stack)
+        assert isinstance(f.vm, Stack)
     assert f(3, output_subset=[0, 1, 2]) == f(3)
     assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]]
...
@@ -277,17 +277,17 @@ def test_allow_gc_cvm():
     f([1])
     n = list(f.maker.fgraph.apply_nodes)[0].outputs[0]
-    assert f.fn.storage_map[n][0] is None
+    assert f.vm.storage_map[n][0] is None
-    assert f.fn.allow_gc is True
+    assert f.vm.allow_gc is True
-    f.fn.allow_gc = False
+    f.vm.allow_gc = False
-    assert f.fn.allow_gc is False
+    assert f.vm.allow_gc is False
     f([1])
-    assert f.fn.storage_map[n][0] is not None
+    assert f.vm.storage_map[n][0] is not None
-    f.fn.allow_gc = True
+    f.vm.allow_gc = True
-    assert f.fn.allow_gc is True
+    assert f.vm.allow_gc is True
     f([1])
-    assert f.fn.storage_map[n][0] is None
+    assert f.vm.storage_map[n][0] is None
 class RunOnce(Op):
...
@@ -334,7 +334,7 @@ def test_reallocation():
         f = function([x, y], z, name="test_reduce_memory", mode=m)
         output = f(1, 2)
         assert output
-        storage_map = f.fn.storage_map
+        storage_map = f.vm.storage_map
         def check_storage(storage_map):
             for i in storage_map:
...
@@ -365,8 +365,8 @@ def test_no_recycling():
         mode = Mode(optimizer="fast_compile", linker=lnk)
         f = function([x], x + 1, mode=mode)
         f2 = function([x], (x + 1) * 2, mode=mode)
-        m1 = f.fn.thunks[0].thunk.module
+        m1 = f.vm.thunks[0].thunk.module
-        m2 = f2.fn.thunks[0].thunk.module
+        m2 = f2.vm.thunks[0].thunk.module
         assert m1 is m2
...
@@ -381,7 +381,7 @@ def test_VMLinker_make_vm_cvm():
     linker = VMLinker(allow_gc=False, use_cloop=True)
     f = function([a], a, mode=Mode(optimizer=None, linker=linker))
-    assert isinstance(f.fn, CVM)
+    assert isinstance(f.vm, CVM)
 def test_VMLinker_make_vm_no_cvm():
...
@@ -405,7 +405,7 @@ def test_VMLinker_make_vm_no_cvm():
                 import aesara.link.c.cvm
             f = function([a], a, mode=Mode(optimizer=None, linker=linker))
-            assert isinstance(f.fn, Loop)
+            assert isinstance(f.vm, Loop)
 def test_VMLinker_exception():
...
tests/sandbox/test_rng_mrg.py
View file @ c2092862
...
@@ -916,7 +916,7 @@ def test_multMatVect():
     r_a1 = rng_mrg.matVecModM(A1, s1, m1)
     r_a2 = rng_mrg.matVecModM(A2, s2, m2)
-    f0.fn()
+    f0.vm()
     r_b = f0.output_storage[0].value
     assert np.allclose(r_a1, r_b[:3])
...
tests/scan/test_basic.py
View file @ c2092862
...
@@ -2702,8 +2702,8 @@ def test_profile_info():
     assert profile.callcount == 0
     assert profile.nbsteps == 0
     assert profile.call_time == 0.0
-    assert fn.fn.call_times == [0.0]
+    assert fn.vm.call_times == [0.0]
-    assert fn.fn.call_counts == [0]
+    assert fn.vm.call_counts == [0]
     z_fn = function([], z)
...
@@ -2716,8 +2716,8 @@ def test_profile_info():
     # Confirm that `VM.update_profile` was called
     assert profile.apply_time
-    assert fn.fn.call_times == [0.0]
+    assert fn.vm.call_times == [0.0]
-    assert fn.fn.call_counts == [0]
+    assert fn.vm.call_counts == [0]
 class TestExamples:
...
tests/tensor/nnet/test_conv.py
View file @ c2092862
...
@@ -616,7 +616,7 @@ class TestConv2D(utt.InferShapeTester):
             )
             aesara_conv = aesara.function([], output, mode=mode)
             t1 = time.time()
-            aesara_conv.fn(n_calls=n_calls)
+            aesara_conv.vm(n_calls=n_calls)
             t2 = time.time()
             print(t2 - t1, end=" ")
         print()
...