Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
434dd96e
提交
434dd96e
authored
11月 16, 2012
作者:
Ian Goodfellow
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
get rid of dangerous "TypeError = not constant" mechanism
上级
5237b952
隐藏空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
40 行增加
和
46 行删除
+40
-46
__init__.py
theano/__init__.py
+2
-1
gradient.py
theano/gradient.py
+1
-1
ops.py
theano/sandbox/linalg/ops.py
+1
-1
scan.py
theano/sandbox/scan.py
+1
-1
scan.py
theano/scan_module/scan.py
+1
-1
basic.py
theano/tensor/basic.py
+9
-16
sigm.py
theano/tensor/nnet/sigm.py
+1
-1
opt.py
theano/tensor/opt.py
+23
-23
test_basic.py
theano/tensor/tests/test_basic.py
+1
-1
没有找到文件。
theano/__init__.py
浏览文件 @
434dd96e
...
...
@@ -171,7 +171,8 @@ def get_constant_value(v):
If theano.sparse is also there, we will look over CSM op.
If `v` is not some view of constant data, then raise a TypeError.
If `v` is not some view of constant data, then raise a
tensor.basic.NotConstantError.
"""
if
hasattr
(
theano
,
'sparse'
)
and
isinstance
(
v
.
type
,
theano
.
sparse
.
SparseType
):
...
...
theano/gradient.py
浏览文件 @
434dd96e
...
...
@@ -1618,7 +1618,7 @@ def _is_zero(x):
try
:
constant_value
=
theano
.
get_constant_value
(
x
)
no_constant_value
=
False
except
Type
Error
:
except
theano
.
tensor
.
basic
.
NotConstant
Error
:
pass
if
no_constant_value
:
...
...
theano/sandbox/linalg/ops.py
浏览文件 @
434dd96e
...
...
@@ -215,7 +215,7 @@ def is_positive(v):
if
v
.
owner
and
v
.
owner
.
op
==
tensor
.
pow
:
try
:
exponent
=
tensor
.
get_constant_value
(
v
.
owner
.
inputs
[
1
])
except
Type
Error
:
except
tensor
.
basic
.
NotConstant
Error
:
return
False
if
0
==
exponent
%
2
:
return
True
...
...
theano/sandbox/scan.py
浏览文件 @
434dd96e
...
...
@@ -136,7 +136,7 @@ def scan(fn,
else
:
try
:
n_fixed_steps
=
opt
.
get_constant_value
(
n_steps
)
except
(
TypeError
,
AttributeError
)
:
except
tensor
.
basic
.
NotConstantError
:
n_fixed_steps
=
None
# Check n_steps is an int
...
...
theano/scan_module/scan.py
浏览文件 @
434dd96e
...
...
@@ -364,7 +364,7 @@ def scan(fn,
else
:
try
:
n_fixed_steps
=
opt
.
get_constant_value
(
n_steps
)
except
(
TypeError
,
AttributeError
)
:
except
tensor
.
basic
.
NotConstantError
:
n_fixed_steps
=
None
# Check n_steps is an int
...
...
theano/tensor/basic.py
浏览文件 @
434dd96e
...
...
@@ -463,17 +463,10 @@ def _allclose(a, b, rtol=None, atol=None):
return
numpy
.
allclose
(
a
,
b
,
atol
=
atol_
,
rtol
=
rtol_
)
class
NotConstantError
(
TypeError
):
class
NotConstantError
(
Exception
):
"""
Raised by get_constant_value if called on something that is
not constant.
For now it is a TypeError, to maintain the old interface
that get_constant_value should raise a TypeError in this
situation. However, this is unsafe because get_constant_value
could inadvertently raise a TypeError if it has a bug.
So we should eventually make NotConstantError derive
from Exception directly, and modify all code that uses
get_constant_value to catch this more specific exception.
"""
pass
...
...
@@ -2364,7 +2357,7 @@ class SpecifyShape(Op):
s
=
get_constant_value
(
node
.
inputs
[
1
][
dim
])
s
=
as_tensor_variable
(
s
)
new_shape
.
append
(
s
)
except
Type
Error
:
except
NotConstant
Error
:
new_shape
.
append
(
node
.
inputs
[
1
][
dim
])
assert
len
(
new_shape
)
==
len
(
xshape
)
...
...
@@ -2662,7 +2655,7 @@ def max(x, axis=None, keepdims=False):
try
:
const
=
get_constant_value
(
axis
)
out
=
CAReduce
(
scal
.
maximum
,
list
(
const
))(
x
)
except
Exception
:
except
NotConstantError
:
out
=
max_and_argmax
(
x
,
axis
)[
0
]
if
keepdims
:
...
...
@@ -3272,7 +3265,7 @@ class Alloc(gof.Op):
# if s is constant 1, then we're broadcastable in that dim
try
:
const_shp
=
get_constant_value
(
s
)
except
Type
Error
:
except
NotConstant
Error
:
const_shp
=
None
bcast
.
append
(
numpy
.
all
(
1
==
const_shp
))
otype
=
TensorType
(
dtype
=
v
.
dtype
,
broadcastable
=
bcast
)
...
...
@@ -3820,7 +3813,7 @@ def extract_constant(x):
'''
try
:
x
=
get_constant_value
(
x
)
except
Exception
:
except
NotConstantError
:
pass
if
(
isinstance
(
x
,
scal
.
ScalarVariable
)
or
isinstance
(
x
,
scal
.
sharedvar
.
ScalarSharedVariable
)):
...
...
@@ -5423,7 +5416,7 @@ class Join(Op):
# int
axis
=
int
(
get_constant_value
(
axis
))
except
Type
Error
:
except
NotConstant
Error
:
pass
if
isinstance
(
axis
,
int
):
# Basically, broadcastable -> length 1, but the
...
...
@@ -5792,7 +5785,7 @@ class Reshape(Op):
try
:
bcasts
[
index
]
=
(
hasattr
(
y
,
'get_constant_value'
)
and
y
.
get_constant_value
()
==
1
)
except
Type
Error
:
except
NotConstant
Error
:
pass
return
gof
.
Apply
(
self
,
[
x
,
shp
],
[
tensor
(
x
.
type
.
dtype
,
bcasts
)])
...
...
@@ -5868,7 +5861,7 @@ class Reshape(Op):
os_i
=
get_constant_value
(
node
.
inputs
[
1
][
i
])
.
item
()
if
os_i
==
-
1
:
os_i
=
default_os_i
except
Type
Error
:
except
NotConstant
Error
:
os_i
=
default_os_i
oshape
.
append
(
os_i
)
return
[
tuple
(
oshape
)]
...
...
@@ -6150,7 +6143,7 @@ class ARange(Op):
try
:
v
=
get_constant_value
(
var
)
return
numpy
.
all
(
v
==
value
)
except
Exception
:
except
NotConstantError
:
pass
return
False
...
...
theano/tensor/nnet/sigm.py
浏览文件 @
434dd96e
...
...
@@ -138,7 +138,7 @@ def _is_1(expr):
try
:
v
=
opt
.
get_constant_value
(
expr
)
return
numpy
.
allclose
(
v
,
1
)
except
Type
Error
:
except
tensor
.
NotConstant
Error
:
return
False
log1msigm_to_softplus
=
gof
.
PatternSub
(
...
...
theano/tensor/opt.py
浏览文件 @
434dd96e
...
...
@@ -33,7 +33,7 @@ from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer
)
from
theano.gof.opt
import
merge_optimizer
from
theano.gof
import
toolbox
,
DestroyHandler
from
basic
import
get_constant_value
,
ShapeError
from
basic
import
get_constant_value
,
ShapeError
,
NotConstantError
theano
.
configparser
.
AddConfigVar
(
'on_shape_error'
,
...
...
@@ -95,7 +95,7 @@ def scalarconsts_rest(inputs):
v
=
get_constant_value
(
i
)
consts
.
append
(
v
)
origconsts
.
append
(
i
)
except
Exception
:
except
NotConstantError
:
nonconsts
.
append
(
i
)
return
consts
,
origconsts
,
nonconsts
...
...
@@ -324,13 +324,13 @@ def local_0_dot_x(node):
try
:
if
get_constant_value
(
x
)
==
0
:
replace
=
True
except
Type
Error
:
except
NotConstant
Error
:
pass
try
:
if
get_constant_value
(
y
)
==
0
:
replace
=
True
except
Type
Error
:
except
NotConstant
Error
:
pass
if
replace
:
...
...
@@ -1179,7 +1179,7 @@ def local_subtensor_make_vector(node):
try
:
v
=
get_constant_value
(
idx
)
return
[
x
.
owner
.
inputs
[
v
]]
except
Exception
:
except
NotConstantError
:
pass
else
:
# it is a slice of ints and/or Variables
...
...
@@ -1321,7 +1321,7 @@ def local_remove_useless_assert(node):
#Should we raise an error here? How to be sure it
#is not catched?
cond
.
append
(
c
)
except
Type
Error
:
except
NotConstant
Error
:
cond
.
append
(
c
)
if
len
(
cond
)
==
0
:
...
...
@@ -1551,7 +1551,7 @@ def local_useless_subtensor(node):
length_pos
=
shape_of
[
node
.
inputs
[
0
]][
pos
]
try
:
length_pos_data
=
get_constant_value
(
length_pos
)
except
Type
Error
:
except
NotConstant
Error
:
pass
if
isinstance
(
idx
.
stop
,
int
):
...
...
@@ -2034,7 +2034,7 @@ def local_incsubtensor_of_allocs(node):
try
:
if
get_constant_value
(
y
)
==
0
:
replace
=
True
except
Type
Error
:
except
NotConstant
Error
:
pass
if
replace
:
...
...
@@ -2060,12 +2060,12 @@ def local_setsubtensor_of_allocs(node):
try
:
replace_x
=
get_constant_value
(
x
)
except
Type
Error
:
except
NotConstant
Error
:
pass
try
:
replace_y
=
get_constant_value
(
y
)
except
Type
Error
:
except
NotConstant
Error
:
pass
if
(
replace_x
==
replace_y
and
...
...
@@ -2260,7 +2260,7 @@ def local_mul_switch_sink(node):
fct
[
0
]
.
values_eq_approx
=
fct
[
0
]
.
type
.
values_eq_approx_remove_nan
return
fct
except
Type
Error
:
except
NotConstant
Error
:
pass
try
:
if
get_constant_value
(
switch
.
inputs
[
2
])
==
0.
:
...
...
@@ -2270,7 +2270,7 @@ def local_mul_switch_sink(node):
fct
[
0
]
.
values_eq_approx
=
fct
[
0
]
.
type
.
values_eq_approx_remove_nan
return
fct
except
Type
Error
:
except
NotConstant
Error
:
pass
return
False
...
...
@@ -2301,7 +2301,7 @@ def local_div_switch_sink(node):
fct
[
0
]
.
values_eq_approx
=
fct
[
0
]
.
type
.
values_eq_approx_remove_nan
return
fct
except
Type
Error
:
except
NotConstant
Error
:
pass
try
:
if
get_constant_value
(
switch
.
inputs
[
2
])
==
0.
:
...
...
@@ -2310,7 +2310,7 @@ def local_div_switch_sink(node):
fct
[
0
]
.
values_eq_approx
=
fct
[
0
]
.
type
.
values_eq_approx_remove_nan
return
fct
except
Type
Error
:
except
NotConstant
Error
:
pass
return
False
...
...
@@ -2703,7 +2703,7 @@ class Canonizer(gof.LocalOptimizer):
if
isinstance
(
v
,
Variable
):
try
:
return
get_constant_value
(
v
)
except
Type
Error
:
except
NotConstant
Error
:
return
None
else
:
return
v
...
...
@@ -3208,7 +3208,7 @@ def local_sum_alloc(node):
assert
val
.
size
==
1
val
=
val
.
reshape
(
1
)[
0
]
*
T
.
mul
(
*
shapes
)
return
[
T
.
cast
(
val
,
dtype
=
node
.
outputs
[
0
]
.
dtype
)]
except
Type
Error
:
except
NotConstant
Error
:
pass
else
:
try
:
...
...
@@ -3222,7 +3222,7 @@ def local_sum_alloc(node):
return
[
T
.
alloc
(
T
.
cast
(
val
,
dtype
=
node
.
outputs
[
0
]
.
dtype
),
*
[
shapes
[
i
]
for
i
in
xrange
(
len
(
shapes
))
if
i
not
in
node
.
op
.
axis
])]
except
Type
Error
:
except
NotConstant
Error
:
pass
...
...
@@ -3283,7 +3283,7 @@ def local_mul_zero(node):
for
i
in
node
.
inputs
:
try
:
value
=
get_constant_value
(
i
)
except
Type
Error
:
except
NotConstant
Error
:
continue
#print 'MUL by value', value, node.inputs
if
N
.
all
(
value
==
0
):
...
...
@@ -3521,7 +3521,7 @@ def local_add_specialize(node):
for
input
in
node
.
inputs
:
try
:
y
=
get_constant_value
(
input
)
except
Type
Error
:
except
NotConstant
Error
:
y
=
input
if
numpy
.
all
(
y
==
0.0
):
continue
...
...
@@ -3882,7 +3882,7 @@ def _is_1(expr):
try
:
v
=
get_constant_value
(
expr
)
return
numpy
.
allclose
(
v
,
1
)
except
Type
Error
:
except
NotConstant
Error
:
return
False
...
...
@@ -3892,7 +3892,7 @@ def _is_minus1(expr):
try
:
v
=
get_constant_value
(
expr
)
return
numpy
.
allclose
(
v
,
-
1
)
except
Type
Error
:
except
NotConstant
Error
:
return
False
#1+erf(x)=>erfc(-x)
...
...
@@ -4133,7 +4133,7 @@ def local_grad_log_erfc_neg(node):
try
:
cst2
=
get_constant_value
(
mul_neg
.
owner
.
inputs
[
0
])
except
Type
Error
:
except
NotConstant
Error
:
return
False
if
len
(
mul_neg
.
owner
.
inputs
)
==
2
:
...
...
@@ -4160,7 +4160,7 @@ def local_grad_log_erfc_neg(node):
x
=
erfc_x
try
:
cst
=
get_constant_value
(
erfc_x
.
owner
.
inputs
[
0
])
except
Type
Error
:
except
NotConstant
Error
:
return
False
if
cst2
!=
-
cst
*
2
:
return
False
...
...
theano/tensor/tests/test_basic.py
浏览文件 @
434dd96e
...
...
@@ -6176,7 +6176,7 @@ class T_get_constant_value(unittest.TestCase):
b
=
tensor
.
iscalar
()
a
=
tensor
.
stack
(
b
,
2
,
3
)
self
.
assertRaises
(
Type
Error
,
get_constant_value
,
a
[
0
])
self
.
assertRaises
(
tensor
.
basic
.
NotConstant
Error
,
get_constant_value
,
a
[
0
])
assert
get_constant_value
(
a
[
1
])
==
2
assert
get_constant_value
(
a
[
2
])
==
3
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论