Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
da6ea2fb
提交
da6ea2fb
authored
10月 18, 2020
作者:
Brandon T. Willard
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename get_debug_values to get_test_values
上级
0ea10588
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
16 行增加
和
16 行删除
+16
-16
test_op.py
tests/gof/test_op.py
+12
-12
op.py
theano/gof/op.py
+1
-1
gradient.py
theano/gradient.py
+3
-3
没有找到文件。
tests/gof/test_op.py
浏览文件 @
da6ea2fb
...
@@ -288,23 +288,23 @@ def test_test_value_op():
...
@@ -288,23 +288,23 @@ def test_test_value_op():
@change_flags
(
compute_test_value
=
"off"
)
@change_flags
(
compute_test_value
=
"off"
)
def
test_get_
debug
_values_no_debugger
():
def
test_get_
test
_values_no_debugger
():
"""Tests that `get_
debug
_values` returns `[]` when debugger is off."""
"""Tests that `get_
test
_values` returns `[]` when debugger is off."""
x
=
tt
.
vector
()
x
=
tt
.
vector
()
assert
op
.
get_
debug
_values
(
x
)
==
[]
assert
op
.
get_
test
_values
(
x
)
==
[]
@change_flags
(
compute_test_value
=
"ignore"
)
@change_flags
(
compute_test_value
=
"ignore"
)
def
test_get_
det_debug
_values_ignore
():
def
test_get_
test
_values_ignore
():
"""Tests that `get_
debug
_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
"""Tests that `get_
test
_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
x
=
tt
.
vector
()
x
=
tt
.
vector
()
assert
op
.
get_
debug
_values
(
x
)
==
[]
assert
op
.
get_
test
_values
(
x
)
==
[]
def
test_get_
debug
_values_success
():
def
test_get_
test
_values_success
():
"""Tests that `get_
debug_value
` returns values when available (and the debugger is on)."""
"""Tests that `get_
test_values
` returns values when available (and the debugger is on)."""
for
mode
in
[
"ignore"
,
"warn"
,
"raise"
]:
for
mode
in
[
"ignore"
,
"warn"
,
"raise"
]:
with
change_flags
(
compute_test_value
=
mode
):
with
change_flags
(
compute_test_value
=
mode
):
...
@@ -314,7 +314,7 @@ def test_get_debug_values_success():
...
@@ -314,7 +314,7 @@ def test_get_debug_values_success():
iters
=
0
iters
=
0
for
x_val
,
y_val
in
op
.
get_
debug
_values
(
x
,
y
):
for
x_val
,
y_val
in
op
.
get_
test
_values
(
x
,
y
):
assert
x_val
.
shape
==
(
4
,)
assert
x_val
.
shape
==
(
4
,)
assert
y_val
.
shape
==
(
5
,
5
)
assert
y_val
.
shape
==
(
5
,
5
)
...
@@ -325,9 +325,9 @@ def test_get_debug_values_success():
...
@@ -325,9 +325,9 @@ def test_get_debug_values_success():
@change_flags
(
compute_test_value
=
"raise"
)
@change_flags
(
compute_test_value
=
"raise"
)
def
test_get_
debug
_values_exc
():
def
test_get_
test
_values_exc
():
"""Tests that `get_
debug_value
` raises an exception when debugger is set to raise and a value is missing."""
"""Tests that `get_
test_values
` raises an exception when debugger is set to raise and a value is missing."""
with
pytest
.
raises
(
AttributeError
):
with
pytest
.
raises
(
AttributeError
):
x
=
tt
.
vector
()
x
=
tt
.
vector
()
assert
op
.
get_
debug
_values
(
x
)
==
[]
assert
op
.
get_
test
_values
(
x
)
==
[]
theano/gof/op.py
浏览文件 @
da6ea2fb
...
@@ -1080,7 +1080,7 @@ def missing_test_message(msg):
...
@@ -1080,7 +1080,7 @@ def missing_test_message(msg):
assert
action
in
[
"ignore"
,
"off"
]
assert
action
in
[
"ignore"
,
"off"
]
def
get_
debug
_values
(
*
args
):
def
get_
test
_values
(
*
args
):
"""
"""
Intended use:
Intended use:
...
...
theano/gradient.py
浏览文件 @
da6ea2fb
...
@@ -15,7 +15,7 @@ from theano import gof
...
@@ -15,7 +15,7 @@ from theano import gof
from
theano.gof
import
utils
,
Variable
from
theano.gof
import
utils
,
Variable
from
theano.gof.null_type
import
NullType
,
null_type
from
theano.gof.null_type
import
NullType
,
null_type
from
theano.gof.op
import
get_
debug
_values
from
theano.gof.op
import
get_
test
_values
from
theano.compile
import
ViewOp
,
FAST_RUN
,
DebugMode
,
get_mode
from
theano.compile
import
ViewOp
,
FAST_RUN
,
DebugMode
,
get_mode
__authors__
=
"James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow"
__authors__
=
"James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow"
...
@@ -1217,7 +1217,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
...
@@ -1217,7 +1217,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
continue
continue
if
isinstance
(
new_output_grad
.
type
,
DisconnectedType
):
if
isinstance
(
new_output_grad
.
type
,
DisconnectedType
):
continue
continue
for
orig_output_v
,
new_output_grad_v
in
get_
debug
_values
(
*
packed
):
for
orig_output_v
,
new_output_grad_v
in
get_
test
_values
(
*
packed
):
o_shape
=
orig_output_v
.
shape
o_shape
=
orig_output_v
.
shape
g_shape
=
new_output_grad_v
.
shape
g_shape
=
new_output_grad_v
.
shape
if
o_shape
!=
g_shape
:
if
o_shape
!=
g_shape
:
...
@@ -1310,7 +1310,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
...
@@ -1310,7 +1310,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
# has the right shape
# has the right shape
if
hasattr
(
term
,
"shape"
):
if
hasattr
(
term
,
"shape"
):
orig_ipt
=
inputs
[
i
]
orig_ipt
=
inputs
[
i
]
for
orig_ipt_v
,
term_v
in
get_
debug
_values
(
orig_ipt
,
term
):
for
orig_ipt_v
,
term_v
in
get_
test
_values
(
orig_ipt
,
term
):
i_shape
=
orig_ipt_v
.
shape
i_shape
=
orig_ipt_v
.
shape
t_shape
=
term_v
.
shape
t_shape
=
term_v
.
shape
if
i_shape
!=
t_shape
:
if
i_shape
!=
t_shape
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论