Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
14c7373e
提交
14c7373e
authored
5月 19, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
5月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix attribute error in Cython version of Scan's exception handling
上级
17ba075e
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
38 行增加
和
23 行删除
+38
-23
scan_perform.c
aesara/scan/c_code/scan_perform.c
+0
-0
scan_perform.pyx
aesara/scan/scan_perform.pyx
+9
-21
scan_perform_ext.py
aesara/scan/scan_perform_ext.py
+1
-1
test_basic.py
tests/scan/test_basic.py
+28
-1
没有找到文件。
aesara/scan/c_code/scan_perform.c
浏览文件 @
14c7373e
This source diff could not be displayed because it is too large. You can
view the blob
instead.
aesara/scan/scan_perform.pyx
浏览文件 @
14c7373e
...
...
@@ -64,7 +64,7 @@ from aesara.link.utils import raise_with_op
def get_version():
return 0.29
8
return 0.29
9
@cython.boundscheck(False)
def perform(
...
...
@@ -153,9 +153,7 @@ def perform(
starting point of implementing this function in C ( we need to take
all the code around the call of this function and put in C inside
that code)
fnct: python object
Only used to attach some timings for the profile mode ( can be
skiped if we don't care about Aesara's profile mode)
fnct: Function
destroy_map
Array of boolean saying if an output is computed inplace
args: list of ndarrays (and random states)
...
...
@@ -404,7 +402,7 @@ def perform(
# done by raise_with_op is not implemented in C.
if hasattr(fn, 'thunks'):
# For the CVM
raise_with_op(fn.maker.fgraph,
raise_with_op(fn
ct
.maker.fgraph,
fn.nodes[fn.position_of_error],
fn.thunks[fn.position_of_error])
else:
...
...
@@ -412,7 +410,7 @@ def perform(
# We don't have access from python to all the
# temps values So for now, we just don't print
# the extra shapes/strides info
raise_with_op(fn.maker.fgraph, fn.nodes[fn.position_of_error])
raise_with_op(fn
ct
.maker.fgraph, fn.nodes[fn.position_of_error])
else:
# old-style linkers raise their own exceptions
raise
...
...
@@ -612,11 +610,11 @@ def perform(
# do not get applied
if i < n_steps:
# Cython can not handle negative indices ( because of a
# deri
ctive at the beginning of the function that says not
# to do boundschecks). The directive is used to make the
# code faster, so this workaround is better then removing
# the directive.
# Cython can not handle negative indices ( because of a
# dire
ctive at the beginning of the function that says not
# to do boundschecks). The directive is used to make the
# code faster, so this workaround is better then removing
# the directive.
sh0 = outs[idx][0].shape[0]
outs[idx][0] = outs[idx][0][:sh0-(n_steps - i)]
...
...
@@ -639,15 +637,5 @@ def perform(
if hasattr(fn, 'update_profile'):
fn.update_profile(profile)
### Old Profile Mode
#if hasattr(fnct.maker.mode,'fct_call_time'):
# fnct.maker.mode.fct_call_time[fnct] += t_fn
# fnct.maker.mode.fct_call[fnct] += n_steps
#fnct.maker.mode.call_time += t_fn
#fnct.maker.mode.fn_time += t_fn
# DEBUG PRINT :
self.t_call = t_call
self.t_fn = t_fn
# print 'Cython > timing', t_call, t_fn, 'in percentage', 100.*t_fn/t_call
aesara/scan/scan_perform_ext.py
浏览文件 @
14c7373e
...
...
@@ -21,7 +21,7 @@ if not config.cxx:
_logger
=
logging
.
getLogger
(
"aesara.scan.scan_perform"
)
version
=
0.29
8
# must match constant returned in function get_version()
version
=
0.29
9
# must match constant returned in function get_version()
need_reload
=
False
...
...
tests/scan/test_basic.py
浏览文件 @
14c7373e
...
...
@@ -37,8 +37,9 @@ from aesara.gradient import (
hessian
,
jacobian
,
)
from
aesara.graph.basic
import
clone_replace
,
graph_inputs
from
aesara.graph.basic
import
Apply
,
clone_replace
,
graph_inputs
from
aesara.graph.fg
import
MissingInputError
from
aesara.graph.op
import
Op
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.scan.basic
import
scan
from
aesara.scan.op
import
Scan
...
...
@@ -4519,6 +4520,32 @@ class TestScan:
assert
detect_large_outputs
.
large_count
==
3
@pytest.mark.skipif
(
not
config
.
cxx
,
reason
=
"G++ not available, so we need to skip this test."
)
def
test_cvm_exception_handling
():
class
MyOp
(
Op
):
def
make_node
(
self
,
input
):
return
Apply
(
self
,
[
input
],
[
vector
()])
def
perform
(
self
,
node
,
inputs
,
outputs
):
raise
Exception
(
"blah"
)
myop
=
MyOp
()
def
scan_fn
():
return
myop
(
aet
.
as_tensor
(
1
))
mode
=
Mode
(
optimizer
=
None
,
linker
=
"cvm"
)
res
,
_
=
scan
(
scan_fn
,
n_steps
=
4
,
mode
=
mode
)
res_fn
=
function
([],
res
,
mode
=
mode
)
with
pytest
.
raises
(
Exception
,
match
=
"blah"
):
res_fn
()
@pytest.mark.skipif
(
not
config
.
cxx
,
reason
=
"G++ not available, so we need to skip this test."
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论