Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
74f80840
提交
74f80840
authored
1月 21, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
1月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move error and profiler handling to the thunk in Scan's Cython implementation
上级
c997333d
显示空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
115 行增加
和
37 行删除
+115
-37
scan_perform.c
aesara/scan/c_code/scan_perform.c
+0
-0
op.py
aesara/scan/op.py
+31
-4
scan_perform.pyx
aesara/scan/scan_perform.pyx
+11
-27
scan_perform_ext.py
aesara/scan/scan_perform_ext.py
+1
-1
utils.py
aesara/scan/utils.py
+6
-2
test_basic.py
tests/scan/test_basic.py
+66
-3
没有找到文件。
aesara/scan/c_code/scan_perform.c
浏览文件 @
74f80840
This source diff could not be displayed because it is too large. You can
view the blob
instead.
aesara/scan/op.py
浏览文件 @
74f80840
...
@@ -43,7 +43,6 @@ relies on the following elements to work properly :
...
@@ -43,7 +43,6 @@ relies on the following elements to work properly :
"""
"""
import
dataclasses
import
dataclasses
import
logging
import
logging
import
time
import
time
...
@@ -1401,11 +1400,23 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1401,11 +1400,23 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
getattr
(
out
,
"ndim"
,
None
)
for
out
in
node
.
outputs
getattr
(
out
,
"ndim"
,
None
)
for
out
in
node
.
outputs
)
)
from
aesara.scan.utils
import
InnerFunctionError
# TODO: Extract `Capsule` object and use that
# c_thunk = getattr(self.fn.fn.thunks[0], "cthunk", None)
# if len(self.fn.fn.thunks) == 1 and c_thunk:
# thunk_capsule = c_thunk.cthunk
# # We need to perform the following after calling
# # the thunk function:
# # for o in node.outputs:
# # compute_map[o][0] = True
def
p
(
node
,
inputs
,
outputs
):
def
p
(
node
,
inputs
,
outputs
):
t0_call
=
time
.
perf_counter
()
t0_call
=
time
.
perf_counter
()
t_fn
=
scan_perform_ext
.
perform
(
try
:
t_fn
,
n_steps
=
scan_perform_ext
.
perform
(
self
.
n_shared_outs
,
self
.
n_shared_outs
,
self
.
n_mit_mot_outs
,
self
.
n_mit_mot_outs
,
self
.
n_seqs
,
self
.
n_seqs
,
...
@@ -1427,13 +1438,29 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1427,13 +1438,29 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_output_storage
,
inner_output_storage
,
getattr
(
self
.
fn
.
fn
,
"need_update_inputs"
,
True
),
getattr
(
self
.
fn
.
fn
,
"need_update_inputs"
,
True
),
inner_input_needs_update
,
inner_input_needs_update
,
self
.
fn
,
cython_destroy_map
,
cython_destroy_map
,
inputs
,
inputs
,
outputs
,
outputs
,
outer_output_dtypes
,
outer_output_dtypes
,
outer_output_ndims
,
outer_output_ndims
,
self
.
fn
.
fn
,
)
)
except
InnerFunctionError
as
exc
:
exc_type
=
type
(
exc
.
args
[
0
])
exc_value
=
exc
.
args
[
0
]
exc_trace
=
exc
.
args
[
1
]
if
hasattr
(
self
.
fn
.
fn
,
"position_of_error"
)
and
hasattr
(
self
.
fn
.
fn
,
"thunks"
):
raise_with_op
(
self
.
fn
.
maker
.
fgraph
,
self
.
fn
.
fn
.
nodes
[
self
.
fn
.
fn
.
position_of_error
],
self
.
fn
.
fn
.
thunks
[
self
.
fn
.
fn
.
position_of_error
],
exc_info
=
(
exc_type
,
exc_value
,
exc_trace
),
)
else
:
raise
exc_value
.
with_traceback
(
exc_trace
)
t_call
=
time
.
perf_counter
()
-
t0_call
t_call
=
time
.
perf_counter
()
-
t0_call
...
@@ -1442,7 +1469,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1442,7 +1469,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if
type
(
profile
)
is
not
bool
and
profile
:
if
type
(
profile
)
is
not
bool
and
profile
:
profile
.
vm_call_time
+=
t_fn
profile
.
vm_call_time
+=
t_fn
profile
.
callcount
+=
1
profile
.
callcount
+=
1
profile
.
nbsteps
+=
outputs
[
0
]
profile
.
nbsteps
+=
n_steps
profile
.
call_time
+=
t_call
profile
.
call_time
+=
t_call
if
hasattr
(
self
.
fn
.
fn
,
"update_profile"
):
if
hasattr
(
self
.
fn
.
fn
,
"update_profile"
):
self
.
fn
.
fn
.
update_profile
(
profile
)
self
.
fn
.
fn
.
update_profile
(
profile
)
...
...
aesara/scan/scan_perform.pyx
浏览文件 @
74f80840
...
@@ -53,12 +53,13 @@ cimport numpy
...
@@ -53,12 +53,13 @@ cimport numpy
import copy
import copy
import time
import time
import sys
from aesara.
link.utils import raise_with_op
from aesara.
scan.utils import InnerFunctionError
def get_version():
def get_version():
return 0.3
0
2
return 0.3
1
2
@cython.boundscheck(False)
@cython.boundscheck(False)
def perform(
def perform(
...
@@ -83,13 +84,13 @@ def perform(
...
@@ -83,13 +84,13 @@ def perform(
list inner_output_storage,
list inner_output_storage,
bint need_update_inputs,
bint need_update_inputs,
tuple inner_input_needs_update,
tuple inner_input_needs_update,
fnct,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
list outer_inputs,
list outer_inputs,
list outer_outputs,
list outer_outputs,
tuple outer_output_dtypes,
tuple outer_output_dtypes,
tuple outer_output_ndims,
tuple outer_output_ndims,
):
fn,
) -> (float, int):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -160,6 +161,8 @@ def perform(
...
@@ -160,6 +161,8 @@ def perform(
The dtypes for each outer output.
The dtypes for each outer output.
outer_output_ndims
outer_output_ndims
The number of dimensions for each outer output.
The number of dimensions for each outer output.
fn
The inner function thunk.
"""
"""
# 1. Unzip the number of steps and sequences. If number of steps is
# 1. Unzip the number of steps and sequences. If number of steps is
...
@@ -258,7 +261,7 @@ def perform(
...
@@ -258,7 +261,7 @@ def perform(
outer_outputs[idx][0] = numpy.empty((0,) * outer_output_ndims[idx], dtype=outer_output_dtypes[idx])
outer_outputs[idx][0] = numpy.empty((0,) * outer_output_ndims[idx], dtype=outer_output_dtypes[idx])
else:
else:
outer_outputs[idx][0] = None
outer_outputs[idx][0] = None
return
return
0.0, 0
for idx in range(n_outs + n_nit_sot):
for idx in range(n_outs + n_nit_sot):
pos[idx] = -mintaps[idx] % store_steps[idx]
pos[idx] = -mintaps[idx] % store_steps[idx]
...
@@ -282,8 +285,6 @@ def perform(
...
@@ -282,8 +285,6 @@ def perform(
for idx in range(len(other_args)):
for idx in range(len(other_args)):
inner_input_storage[<unsigned int>(idx+offset)][0] = other_args[idx]
inner_input_storage[<unsigned int>(idx+offset)][0] = other_args[idx]
fn = fnct.fn
i = 0
i = 0
cond = 1
cond = 1
############## THE MAIN LOOP #########################
############## THE MAIN LOOP #########################
...
@@ -398,25 +399,8 @@ def perform(
...
@@ -398,25 +399,8 @@ def perform(
try:
try:
fn()
fn()
except Exception:
except Exception as exc:
if hasattr(fn, 'position_of_error'):
raise InnerFunctionError(exc, sys.exc_info()[-1])
# this is a new vm-provided function
# the C VM needs this because the exception manipulation
# done by raise_with_op is not implemented in C.
if hasattr(fn, 'thunks'):
# For the CVM
raise_with_op(fnct.maker.fgraph,
fn.nodes[fn.position_of_error],
fn.thunks[fn.position_of_error])
else:
# For the c linker
# We don't have access from python to all the
# temps values So for now, we just don't print
# the extra shapes/strides info
raise_with_op(fnct.maker.fgraph, fn.nodes[fn.position_of_error])
else:
# old-style linkers raise their own exceptions
raise
dt_fn = time.time() - t0_fn
dt_fn = time.time() - t0_fn
t_fn += dt_fn
t_fn += dt_fn
...
@@ -625,4 +609,4 @@ def perform(
...
@@ -625,4 +609,4 @@ def perform(
for s in inner_output_storage:
for s in inner_output_storage:
s[0] = None
s[0] = None
return t_fn
return t_fn
, i
aesara/scan/scan_perform_ext.py
浏览文件 @
74f80840
...
@@ -21,7 +21,7 @@ if not config.cxx:
...
@@ -21,7 +21,7 @@ if not config.cxx:
_logger
=
logging
.
getLogger
(
"aesara.scan.scan_perform"
)
_logger
=
logging
.
getLogger
(
"aesara.scan.scan_perform"
)
version
=
0.3
0
2
# must match constant returned in function get_version()
version
=
0.3
1
2
# must match constant returned in function get_version()
need_reload
=
False
need_reload
=
False
...
...
aesara/scan/utils.py
浏览文件 @
74f80840
...
@@ -35,6 +35,10 @@ if TYPE_CHECKING:
...
@@ -35,6 +35,10 @@ if TYPE_CHECKING:
_logger
=
logging
.
getLogger
(
"aesara.scan.utils"
)
_logger
=
logging
.
getLogger
(
"aesara.scan.utils"
)
class
InnerFunctionError
(
Exception
):
"""An exception indicating that an error occurred in `Scan`'s inner function."""
def
safe_new
(
def
safe_new
(
x
:
Variable
,
tag
:
str
=
""
,
dtype
:
Optional
[
Union
[
str
,
np
.
dtype
]]
=
None
x
:
Variable
,
tag
:
str
=
""
,
dtype
:
Optional
[
Union
[
str
,
np
.
dtype
]]
=
None
)
->
Variable
:
)
->
Variable
:
...
@@ -126,8 +130,8 @@ class until:
...
@@ -126,8 +130,8 @@ class until:
class
ScanProfileStats
(
ProfileStats
):
class
ScanProfileStats
(
ProfileStats
):
show_sum
=
False
show_sum
=
False
callcount
=
0
.0
callcount
=
0
nbsteps
=
0
.0
nbsteps
=
0
call_time
=
0.0
call_time
=
0.0
def
__init__
(
self
,
atexit_print
=
True
,
name
=
None
,
**
kwargs
):
def
__init__
(
self
,
atexit_print
=
True
,
name
=
None
,
**
kwargs
):
...
...
tests/scan/test_basic.py
浏览文件 @
74f80840
...
@@ -4635,7 +4635,10 @@ class TestScan:
...
@@ -4635,7 +4635,10 @@ class TestScan:
@pytest.mark.skipif
(
@pytest.mark.skipif
(
not
config
.
cxx
,
reason
=
"G++ not available, so we need to skip this test."
not
config
.
cxx
,
reason
=
"G++ not available, so we need to skip this test."
)
)
def
test_cvm_exception_handling
():
@pytest.mark.parametrize
(
"mode"
,
[
Mode
(
linker
=
"c|py"
,
optimizer
=
None
),
Mode
(
linker
=
"cvm"
,
optimizer
=
None
)]
)
def
test_cvm_exception_handling
(
mode
):
class
MyOp
(
Op
):
class
MyOp
(
Op
):
def
make_node
(
self
,
input
):
def
make_node
(
self
,
input
):
return
Apply
(
self
,
[
input
],
[
vector
()])
return
Apply
(
self
,
[
input
],
[
vector
()])
...
@@ -4643,13 +4646,18 @@ def test_cvm_exception_handling():
...
@@ -4643,13 +4646,18 @@ def test_cvm_exception_handling():
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
raise
Exception
(
"blah"
)
raise
Exception
(
"blah"
)
# def c_code(self, node, name, inputs, outputs, sub):
# fail = sub["fail"]
# return f"""
# PyErr_SetString(PyExc_Exception, "blah");
# {fail};
# """
myop
=
MyOp
()
myop
=
MyOp
()
def
scan_fn
():
def
scan_fn
():
return
myop
(
at
.
as_tensor
(
1
))
return
myop
(
at
.
as_tensor
(
1
))
mode
=
Mode
(
optimizer
=
None
,
linker
=
"cvm"
)
res
,
_
=
scan
(
scan_fn
,
n_steps
=
4
,
mode
=
mode
)
res
,
_
=
scan
(
scan_fn
,
n_steps
=
4
,
mode
=
mode
)
res_fn
=
function
([],
res
,
mode
=
mode
)
res_fn
=
function
([],
res
,
mode
=
mode
)
...
@@ -5198,3 +5206,58 @@ def test_inner_get_vector_length():
...
@@ -5198,3 +5206,58 @@ def test_inner_get_vector_length():
res_fn
=
function
([],
res
.
shape
)
res_fn
=
function
([],
res
.
shape
)
assert
np
.
array_equal
(
res_fn
(),
(
10
,
3
))
assert
np
.
array_equal
(
res_fn
(),
(
10
,
3
))
@config.change_flags
(
mode
=
Mode
(
"cvm"
,
None
))
def
test_profile_info
():
from
aesara.scan.utils
import
ScanProfileStats
z
,
updates
=
scan
(
fn
=
lambda
u
:
u
+
1
,
sequences
=
[
at
.
arange
(
10
)],
profile
=
True
)
assert
isinstance
(
z
.
owner
.
op
,
Scan
)
fn
=
z
.
owner
.
op
.
fn
assert
isinstance
(
fn
.
profile
,
ScanProfileStats
)
assert
fn
.
profile
.
name
==
"scan_fn"
# Set the `ScanProfileStats` name
z
,
updates
=
scan
(
fn
=
lambda
u
:
u
+
1
,
sequences
=
[
at
.
arange
(
10
)],
profile
=
"profile_name"
)
assert
isinstance
(
z
.
owner
.
op
,
Scan
)
fn
=
z
.
owner
.
op
.
fn
assert
isinstance
(
fn
.
profile
,
ScanProfileStats
)
assert
fn
.
profile
.
name
==
"profile_name"
# Use an existing profile object
profile
=
fn
.
profile
z
,
updates
=
scan
(
fn
=
lambda
u
:
u
+
1
,
sequences
=
[
at
.
arange
(
10
)],
profile
=
profile
)
assert
isinstance
(
z
.
owner
.
op
,
Scan
)
fn
=
z
.
owner
.
op
.
fn
assert
fn
.
profile
is
profile
assert
not
profile
.
apply_time
assert
profile
.
callcount
==
0
assert
profile
.
nbsteps
==
0
assert
profile
.
call_time
==
0.0
assert
fn
.
fn
.
call_times
==
[
0.0
]
assert
fn
.
fn
.
call_counts
==
[
0
]
z_fn
=
function
([],
z
)
_
=
z_fn
()
# assert profile.vm_call_time > 0
assert
profile
.
callcount
==
1
assert
profile
.
nbsteps
==
10
assert
profile
.
call_time
>
0
# Confirm that `VM.update_profile` was called
assert
profile
.
apply_time
assert
fn
.
fn
.
call_times
==
[
0.0
]
assert
fn
.
fn
.
call_counts
==
[
0
]
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论