Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
483cb9d3
提交
483cb9d3
authored
1月 15, 2013
作者:
lamblin
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1161 from goodfeli/rebase
Ready to merge: get rid of dangerous "TypeError = not constant" mechanism
上级
b7e9de0a
f9292ae7
全部展开
显示空白字符变更
内嵌
并排
正在显示
22 个修改的文件
包含
63 行增加
和
62 行删除
+63
-62
__init__.py
theano/__init__.py
+5
-4
gradient.py
theano/gradient.py
+3
-3
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+2
-2
ops.py
theano/sandbox/linalg/ops.py
+2
-2
scan.py
theano/sandbox/scan.py
+2
-2
scan.py
theano/sandbox/scan_module/scan.py
+1
-1
scan_utils.py
theano/sandbox/scan_module/scan_utils.py
+1
-1
scan.py
theano/scan_module/scan.py
+2
-2
scan_opt.py
theano/scan_module/scan_opt.py
+5
-5
scan_utils.py
theano/scan_module/scan_utils.py
+2
-2
basic.py
theano/tensor/basic.py
+0
-0
blas.py
theano/tensor/blas.py
+1
-1
extra_ops.py
theano/tensor/extra_ops.py
+2
-2
conv.py
theano/tensor/nnet/conv.py
+3
-3
nnet.py
theano/tensor/nnet/nnet.py
+2
-2
sigm.py
theano/tensor/nnet/sigm.py
+6
-6
test_conv.py
theano/tensor/nnet/tests/test_conv.py
+2
-2
opt.py
theano/tensor/opt.py
+0
-0
opt_uncanonicalize.py
theano/tensor/opt_uncanonicalize.py
+3
-3
test_basic.py
theano/tensor/tests/test_basic.py
+17
-17
test_elemwise.py
theano/tensor/tests/test_elemwise.py
+1
-1
test_tutorial.py
theano/tests/test_tutorial.py
+1
-1
没有找到文件。
theano/__init__.py
浏览文件 @
483cb9d3
...
...
@@ -163,7 +163,7 @@ def dot(l, r):
return
rval
def
get_constant_value
(
v
):
def
get_
scalar_
constant_value
(
v
):
"""return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
...
...
@@ -171,12 +171,13 @@ def get_constant_value(v):
If theano.sparse is also there, we will look over CSM op.
If `v` is not some view of constant data, then raise a TypeError.
If `v` is not some view of constant data, then raise a
tensor.basic.NotScalarConstantError.
"""
if
hasattr
(
theano
,
'sparse'
)
and
isinstance
(
v
.
type
,
theano
.
sparse
.
SparseType
):
if
v
.
owner
is
not
None
and
isinstance
(
v
.
owner
.
op
,
theano
.
sparse
.
CSM
):
data
=
v
.
owner
.
inputs
[
0
]
return
tensor
.
get_constant_value
(
data
)
return
tensor
.
get_constant_value
(
v
)
return
tensor
.
get_
scalar_
constant_value
(
data
)
return
tensor
.
get_
scalar_
constant_value
(
v
)
theano/gradient.py
浏览文件 @
483cb9d3
...
...
@@ -975,7 +975,7 @@ def _populate_grad_dict(var_to_app_to_idx,
msg
+=
"
%
s."
msg
%
(
str
(
node
.
op
),
str
(
term
),
str
(
type
(
term
)),
i
,
str
(
theano
.
get_constant_value
(
term
)))
i
,
str
(
theano
.
get_
scalar_
constant_value
(
term
)))
raise
ValueError
(
msg
)
...
...
@@ -1616,9 +1616,9 @@ def _is_zero(x):
no_constant_value
=
True
try
:
constant_value
=
theano
.
get_constant_value
(
x
)
constant_value
=
theano
.
get_
scalar_
constant_value
(
x
)
no_constant_value
=
False
except
Type
Error
:
except
theano
.
tensor
.
basic
.
NotScalarConstant
Error
:
pass
if
no_constant_value
:
...
...
theano/sandbox/cuda/basic_ops.py
浏览文件 @
483cb9d3
...
...
@@ -2691,8 +2691,8 @@ class GpuAlloc(GpuOp):
raise
TypeError
(
'Shape arguments must be integers'
,
s
)
# if s is constant 1, then we're broadcastable in that dim
try
:
const_shp
=
tensor
.
get_constant_value
(
s
)
except
Type
Error
:
const_shp
=
tensor
.
get_
scalar_
constant_value
(
s
)
except
tensor
.
NotScalarConstant
Error
:
const_shp
=
None
bcast
.
append
(
numpy
.
all
(
1
==
const_shp
))
otype
=
CudaNdarrayType
(
dtype
=
'float32'
,
broadcastable
=
bcast
)
...
...
theano/sandbox/linalg/ops.py
浏览文件 @
483cb9d3
...
...
@@ -214,8 +214,8 @@ def is_positive(v):
logger
.
debug
(
'is_positive:
%
s'
%
str
(
v
))
if
v
.
owner
and
v
.
owner
.
op
==
tensor
.
pow
:
try
:
exponent
=
tensor
.
get_constant_value
(
v
.
owner
.
inputs
[
1
])
except
Type
Error
:
exponent
=
tensor
.
get_
scalar_
constant_value
(
v
.
owner
.
inputs
[
1
])
except
tensor
.
basic
.
NotScalarConstant
Error
:
return
False
if
0
==
exponent
%
2
:
return
True
...
...
theano/sandbox/scan.py
浏览文件 @
483cb9d3
...
...
@@ -135,8 +135,8 @@ def scan(fn,
n_fixed_steps
=
int
(
n_steps
)
else
:
try
:
n_fixed_steps
=
opt
.
get_constant_value
(
n_steps
)
except
(
TypeError
,
AttributeError
)
:
n_fixed_steps
=
opt
.
get_
scalar_
constant_value
(
n_steps
)
except
tensor
.
basic
.
NotScalarConstantError
:
n_fixed_steps
=
None
# Check n_steps is an int
...
...
theano/sandbox/scan_module/scan.py
浏览文件 @
483cb9d3
...
...
@@ -335,7 +335,7 @@ def scan(fn,
T_value
=
int
(
n_steps
)
else
:
try
:
T_value
=
opt
.
get_constant_value
(
n_steps
)
T_value
=
opt
.
get_
scalar_
constant_value
(
n_steps
)
except
(
TypeError
,
AttributeError
):
T_value
=
None
...
...
theano/sandbox/scan_module/scan_utils.py
浏览文件 @
483cb9d3
...
...
@@ -24,7 +24,7 @@ from theano.compile.pfunc import rebuild_collect_shared
from
theano
import
gof
from
theano
import
tensor
,
scalar
from
theano.gof.python25
import
all
from
theano.tensor.basic
import
get_constant_value
from
theano.tensor.basic
import
get_
scalar_
constant_value
# Logging function for sending warning or info
...
...
theano/scan_module/scan.py
浏览文件 @
483cb9d3
...
...
@@ -363,8 +363,8 @@ def scan(fn,
n_fixed_steps
=
int
(
n_steps
)
else
:
try
:
n_fixed_steps
=
opt
.
get_constant_value
(
n_steps
)
except
(
TypeError
,
AttributeError
)
:
n_fixed_steps
=
opt
.
get_
scalar_
constant_value
(
n_steps
)
except
tensor
.
basic
.
NotScalarConstantError
:
n_fixed_steps
=
None
# Check n_steps is an int
...
...
theano/scan_module/scan_opt.py
浏览文件 @
483cb9d3
...
...
@@ -18,7 +18,7 @@ import numpy
import
theano
from
theano
import
tensor
from
theano.tensor
import
opt
,
get_constant_value
from
theano.tensor
import
opt
,
get_
scalar_
constant_value
from
theano
import
gof
from
theano.gof.python25
import
maxsize
,
any
from
theano.gof.opt
import
Optimizer
...
...
@@ -1164,14 +1164,14 @@ class ScanMerge(gof.Optimizer):
nsteps
=
node
.
inputs
[
0
]
try
:
nsteps
=
int
(
get_constant_value
(
nsteps
))
except
Type
Error
:
nsteps
=
int
(
get_
scalar_
constant_value
(
nsteps
))
except
tensor
.
NotScalarConstant
Error
:
pass
rep_nsteps
=
rep
.
inputs
[
0
]
try
:
rep_nsteps
=
int
(
get_constant_value
(
rep_nsteps
))
except
Type
Error
:
rep_nsteps
=
int
(
get_
scalar_
constant_value
(
rep_nsteps
))
except
tensor
.
NotScalarConstant
Error
:
pass
# Check to see if it is an input of a different node
...
...
theano/scan_module/scan_utils.py
浏览文件 @
483cb9d3
...
...
@@ -25,7 +25,7 @@ from theano.compile.pfunc import rebuild_collect_shared
from
theano
import
gof
from
theano
import
tensor
,
scalar
from
theano.gof.python25
import
all
,
OrderedDict
from
theano.tensor.basic
import
get_constant_value
from
theano.tensor.basic
import
get_
scalar_
constant_value
################ Utility Functions and Classes #######################
...
...
@@ -308,7 +308,7 @@ def isNaN_or_Inf_or_None(x):
isStr
=
False
if
not
isNaN
and
not
isInf
:
try
:
val
=
get_constant_value
(
x
)
val
=
get_
scalar_
constant_value
(
x
)
isInf
=
numpy
.
isinf
(
val
)
isNaN
=
numpy
.
isnan
(
val
)
except
Exception
:
...
...
theano/tensor/basic.py
浏览文件 @
483cb9d3
差异被折叠。
点击展开。
theano/tensor/blas.py
浏览文件 @
483cb9d3
...
...
@@ -1614,7 +1614,7 @@ def local_gemm_to_ger(node):
xv
=
x
.
dimshuffle
(
0
)
yv
=
y
.
dimshuffle
(
1
)
try
:
bval
=
T
.
get_constant_value
(
b
)
bval
=
T
.
get_
scalar_
constant_value
(
b
)
except
TypeError
:
# b isn't a constant, GEMM is doing useful pre-scaling
return
...
...
theano/tensor/extra_ops.py
浏览文件 @
483cb9d3
...
...
@@ -262,8 +262,8 @@ class RepeatOp(theano.Op):
broadcastable
=
[
False
]
else
:
try
:
const_reps
=
basic
.
get_constant_value
(
repeats
)
except
basic
.
NotConstantError
:
const_reps
=
basic
.
get_
scalar_
constant_value
(
repeats
)
except
basic
.
Not
Scalar
ConstantError
:
const_reps
=
None
if
const_reps
==
1
:
broadcastable
=
x
.
broadcastable
...
...
theano/tensor/nnet/conv.py
浏览文件 @
483cb9d3
...
...
@@ -15,7 +15,7 @@ import logging
import
numpy
import
theano
from
theano.tensor
import
(
as_tensor_variable
,
blas
,
get_constant_value
,
from
theano.tensor
import
(
as_tensor_variable
,
blas
,
get_
scalar_
constant_value
,
patternbroadcast
)
from
theano
import
OpenMPOp
,
config
from
theano.gof
import
Apply
...
...
@@ -90,7 +90,7 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
image_shape
=
list
(
image_shape
)
for
i
in
xrange
(
len
(
image_shape
)):
if
image_shape
[
i
]
is
not
None
:
image_shape
[
i
]
=
get_constant_value
(
image_shape
[
i
]
=
get_
scalar_
constant_value
(
as_tensor_variable
(
image_shape
[
i
]))
assert
str
(
image_shape
[
i
]
.
dtype
)
.
startswith
(
'int'
)
image_shape
[
i
]
=
int
(
image_shape
[
i
])
...
...
@@ -98,7 +98,7 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
filter_shape
=
list
(
filter_shape
)
for
i
in
xrange
(
len
(
filter_shape
)):
if
filter_shape
[
i
]
is
not
None
:
filter_shape
[
i
]
=
get_constant_value
(
filter_shape
[
i
]
=
get_
scalar_
constant_value
(
as_tensor_variable
(
filter_shape
[
i
]))
assert
str
(
filter_shape
[
i
]
.
dtype
)
.
startswith
(
'int'
)
filter_shape
[
i
]
=
int
(
filter_shape
[
i
])
...
...
theano/tensor/nnet/nnet.py
浏览文件 @
483cb9d3
...
...
@@ -1409,8 +1409,8 @@ def _check_rows_is_arange_len_labels(rows, labels):
def
_is_const
(
z
,
val
,
approx
=
False
):
try
:
maybe
=
opt
.
get_constant_value
(
z
)
except
Type
Error
:
maybe
=
opt
.
get_
scalar_
constant_value
(
z
)
except
tensor
.
NotScalarConstant
Error
:
return
False
if
approx
:
return
numpy
.
allclose
(
maybe
,
val
)
...
...
theano/tensor/nnet/sigm.py
浏览文件 @
483cb9d3
...
...
@@ -14,7 +14,7 @@ from theano.compile import optdb
from
theano.configparser
import
AddConfigVar
,
BoolParam
from
theano.printing
import
pprint
,
debugprint
from
theano.tensor
import
basic
as
tensor
from
theano.tensor
import
elemwise
,
opt
from
theano.tensor
import
elemwise
,
opt
,
NotScalarConstantError
############
...
...
@@ -136,9 +136,9 @@ def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1
"""
try
:
v
=
opt
.
get_constant_value
(
expr
)
v
=
opt
.
get_
scalar_
constant_value
(
expr
)
return
numpy
.
allclose
(
v
,
1
)
except
Type
Error
:
except
tensor
.
NotScalarConstant
Error
:
return
False
log1msigm_to_softplus
=
gof
.
PatternSub
(
...
...
@@ -275,9 +275,9 @@ def is_neg(var):
if
apply
.
op
==
tensor
.
mul
and
len
(
apply
.
inputs
)
>=
2
:
for
idx
,
mul_input
in
enumerate
(
apply
.
inputs
):
try
:
constant
=
opt
.
get_constant_value
(
mul_input
)
constant
=
opt
.
get_
scalar_
constant_value
(
mul_input
)
is_minus_1
=
numpy
.
allclose
(
constant
,
-
1
)
except
Type
Error
:
except
NotScalarConstant
Error
:
is_minus_1
=
False
if
is_minus_1
:
# Found a multiplication by -1.
...
...
@@ -647,7 +647,7 @@ def local_1msigmoid(node):
return
# graph is using both sigm and 1-sigm
if
sub_r
.
owner
and
sub_r
.
owner
.
op
==
sigmoid
:
try
:
val_l
=
opt
.
get_constant_value
(
sub_l
)
val_l
=
opt
.
get_
scalar_
constant_value
(
sub_l
)
except
Exception
,
e
:
return
if
numpy
.
allclose
(
numpy
.
sum
(
val_l
),
1
):
...
...
theano/tensor/nnet/tests/test_conv.py
浏览文件 @
483cb9d3
...
...
@@ -30,10 +30,10 @@ class TestConv2D(utt.InferShapeTester):
verify_grad
=
True
,
should_raise
=
False
):
if
N_image_shape
is
None
:
N_image_shape
=
[
T
.
get_constant_value
(
T
.
N_image_shape
=
[
T
.
get_
scalar_
constant_value
(
T
.
as_tensor_variable
(
x
))
for
x
in
image_shape
]
if
N_filter_shape
is
None
:
N_filter_shape
=
[
T
.
get_constant_value
(
T
.
N_filter_shape
=
[
T
.
get_
scalar_
constant_value
(
T
.
as_tensor_variable
(
x
))
for
x
in
filter_shape
]
if
input
is
None
:
...
...
theano/tensor/opt.py
浏览文件 @
483cb9d3
差异被折叠。
点击展开。
theano/tensor/opt_uncanonicalize.py
浏览文件 @
483cb9d3
...
...
@@ -41,7 +41,7 @@ from theano.gof.python25 import any, all
from
theano.gof.opt
import
Optimizer
from
theano.gof
import
InconsistencyError
,
toolbox
from
basic
import
get_
constant_value
from
basic
import
get_
scalar_constant_value
,
NotScalarConstantError
from
theano.tensor.opt
import
register_uncanonicalize
from
theano
import
scalar
as
scal
...
...
@@ -64,8 +64,8 @@ class MaxAndArgmaxOptimizer(Optimizer):
if
node
.
op
==
T
.
_max_and_argmax
:
if
len
(
node
.
outputs
[
1
]
.
clients
)
==
0
:
try
:
axis
=
get_constant_value
(
node
.
inputs
[
1
])
except
(
ValueError
,
TypeError
),
e
:
axis
=
get_
scalar_
constant_value
(
node
.
inputs
[
1
])
except
NotScalarConstantError
:
return
False
new
=
CAReduce
(
scal
.
maximum
,
axis
)(
node
.
inputs
[
0
])
...
...
theano/tensor/tests/test_basic.py
浏览文件 @
483cb9d3
...
...
@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
Reshape
,
row
,
scalar
,
scalars
,
second
,
smallest
,
stack
,
sub
,
Tensor
,
tensor_copy
,
tensordot
,
tensordot_grad
,
TensorType
,
unbroadcast
,
var
,
Join
,
shape
,
MaxAndArgmax
,
lscalar
,
zvector
,
exp
,
get_constant_value
,
ivector
,
reshape
,
scalar_from_tensor
,
scal
,
get_
scalar_
constant_value
,
ivector
,
reshape
,
scalar_from_tensor
,
scal
,
iscalars
,
arange
,
dscalars
,
fvector
,
imatrix
,
numeric_grad
,
opt
,
ComplexError
,
TensorDot
,
lvector
,
true_div
,
max
,
min
,
Split
,
roll
,
tile
,
patternbroadcast
,
Eye
,
Shape
,
Default
,
Dot
,
PermuteRowElements
,
...
...
@@ -2140,7 +2140,7 @@ class T_max_and_argmax(unittest.TestCase):
cost
=
argmax
(
x
,
axis
=
0
)
.
sum
()
value_error_raised
=
False
gx
=
grad
(
cost
,
x
)
val
=
tensor
.
get_constant_value
(
gx
)
val
=
tensor
.
get_
scalar_
constant_value
(
gx
)
assert
val
==
0.0
def
test_grad
(
self
):
...
...
@@ -6167,40 +6167,40 @@ def test_dimshuffle_duplicate():
assert
success
class
T_get_constant_value
(
unittest
.
TestCase
):
def
test_get_constant_value
(
self
):
class
T_get_
scalar_
constant_value
(
unittest
.
TestCase
):
def
test_get_
scalar_
constant_value
(
self
):
a
=
tensor
.
stack
(
1
,
2
,
3
)
assert
get_constant_value
(
a
[
0
])
==
1
assert
get_constant_value
(
a
[
1
])
==
2
assert
get_constant_value
(
a
[
2
])
==
3
assert
get_
scalar_
constant_value
(
a
[
0
])
==
1
assert
get_
scalar_
constant_value
(
a
[
1
])
==
2
assert
get_
scalar_
constant_value
(
a
[
2
])
==
3
b
=
tensor
.
iscalar
()
a
=
tensor
.
stack
(
b
,
2
,
3
)
self
.
assertRaises
(
TypeError
,
get
_constant_value
,
a
[
0
])
assert
get_constant_value
(
a
[
1
])
==
2
assert
get_constant_value
(
a
[
2
])
==
3
self
.
assertRaises
(
tensor
.
basic
.
NotScalarConstantError
,
get_scalar
_constant_value
,
a
[
0
])
assert
get_
scalar_
constant_value
(
a
[
1
])
==
2
assert
get_
scalar_
constant_value
(
a
[
2
])
==
3
# For now get_constant_value goes through only MakeVector and Join of
# For now get_
scalar_
constant_value goes through only MakeVector and Join of
# scalars.
v
=
tensor
.
ivector
()
a
=
tensor
.
stack
(
v
,
2
,
3
)
self
.
assertRaises
(
TypeError
,
get
_constant_value
,
a
[
0
])
self
.
assertRaises
(
TypeError
,
get
_constant_value
,
a
[
1
])
self
.
assertRaises
(
TypeError
,
get
_constant_value
,
a
[
2
])
self
.
assertRaises
(
tensor
.
NotScalarConstantError
,
get_scalar
_constant_value
,
a
[
0
])
self
.
assertRaises
(
tensor
.
NotScalarConstantError
,
get_scalar
_constant_value
,
a
[
1
])
self
.
assertRaises
(
tensor
.
NotScalarConstantError
,
get_scalar
_constant_value
,
a
[
2
])
# Test the case SubTensor(Shape(v)) when the dimensions
# is broadcastable.
v
=
tensor
.
row
()
assert
get_constant_value
(
v
.
shape
[
0
])
==
1
assert
get_
scalar_
constant_value
(
v
.
shape
[
0
])
==
1
def
test_subtensor_of_constant
(
self
):
c
=
constant
(
rand
(
5
))
for
i
in
range
(
c
.
value
.
shape
[
0
]):
assert
get_constant_value
(
c
[
i
])
==
c
.
value
[
i
]
assert
get_
scalar_
constant_value
(
c
[
i
])
==
c
.
value
[
i
]
c
=
constant
(
rand
(
5
,
5
))
for
i
in
range
(
c
.
value
.
shape
[
0
]):
for
j
in
range
(
c
.
value
.
shape
[
1
]):
assert
get_constant_value
(
c
[
i
,
j
])
==
c
.
value
[
i
,
j
]
assert
get_
scalar_
constant_value
(
c
[
i
,
j
])
==
c
.
value
[
i
,
j
]
class
T_as_tensor_variable
(
unittest
.
TestCase
):
...
...
theano/tensor/tests/test_elemwise.py
浏览文件 @
483cb9d3
...
...
@@ -856,7 +856,7 @@ def test_gt_grad():
"""A user test that failed.
Something about it made Elemwise.grad return something that was
too complicated for get_constant_value to recognize as being 0, so
too complicated for get_
scalar_
constant_value to recognize as being 0, so
gradient.grad reported that it was not a valid gradient of an
integer.
...
...
theano/tests/test_tutorial.py
浏览文件 @
483cb9d3
...
...
@@ -936,7 +936,7 @@ class T_fibby(unittest.TestCase):
if
node
.
op
==
fibby
:
x
=
node
.
inputs
[
0
]
try
:
if
numpy
.
all
(
0
==
get_constant_value
(
x
)):
if
numpy
.
all
(
0
==
get_
scalar_
constant_value
(
x
)):
return
[
x
]
except
TypeError
:
pass
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论