Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
483cb9d3
提交
483cb9d3
authored
1月 15, 2013
作者:
lamblin
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1161 from goodfeli/rebase
Ready to merge: get rid of dangerous "TypeError = not constant" mechanism
上级
b7e9de0a
f9292ae7
隐藏空白字符变更
内嵌
并排
正在显示
22 个修改的文件
包含
205 行增加
和
172 行删除
+205
-172
__init__.py
theano/__init__.py
+5
-4
gradient.py
theano/gradient.py
+3
-3
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+2
-2
ops.py
theano/sandbox/linalg/ops.py
+2
-2
scan.py
theano/sandbox/scan.py
+2
-2
scan.py
theano/sandbox/scan_module/scan.py
+1
-1
scan_utils.py
theano/sandbox/scan_module/scan_utils.py
+1
-1
scan.py
theano/scan_module/scan.py
+2
-2
scan_opt.py
theano/scan_module/scan_opt.py
+5
-5
scan_utils.py
theano/scan_module/scan_utils.py
+2
-2
basic.py
theano/tensor/basic.py
+85
-59
blas.py
theano/tensor/blas.py
+1
-1
extra_ops.py
theano/tensor/extra_ops.py
+3
-3
conv.py
theano/tensor/nnet/conv.py
+3
-3
nnet.py
theano/tensor/nnet/nnet.py
+2
-2
sigm.py
theano/tensor/nnet/sigm.py
+6
-6
test_conv.py
theano/tensor/nnet/tests/test_conv.py
+2
-2
opt.py
theano/tensor/opt.py
+56
-50
opt_uncanonicalize.py
theano/tensor/opt_uncanonicalize.py
+3
-3
test_basic.py
theano/tensor/tests/test_basic.py
+17
-17
test_elemwise.py
theano/tensor/tests/test_elemwise.py
+1
-1
test_tutorial.py
theano/tests/test_tutorial.py
+1
-1
没有找到文件。
theano/__init__.py
浏览文件 @
483cb9d3
...
...
@@ -163,7 +163,7 @@ def dot(l, r):
return
rval
def
get_constant_value
(
v
):
def
get_
scalar_
constant_value
(
v
):
"""return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
...
...
@@ -171,12 +171,13 @@ def get_constant_value(v):
If theano.sparse is also there, we will look over CSM op.
If `v` is not some view of constant data, then raise a TypeError.
If `v` is not some view of constant data, then raise a
tensor.basic.NotScalarConstantError.
"""
if
hasattr
(
theano
,
'sparse'
)
and
isinstance
(
v
.
type
,
theano
.
sparse
.
SparseType
):
if
v
.
owner
is
not
None
and
isinstance
(
v
.
owner
.
op
,
theano
.
sparse
.
CSM
):
data
=
v
.
owner
.
inputs
[
0
]
return
tensor
.
get_constant_value
(
data
)
return
tensor
.
get_constant_value
(
v
)
return
tensor
.
get_
scalar_
constant_value
(
data
)
return
tensor
.
get_
scalar_
constant_value
(
v
)
theano/gradient.py
浏览文件 @
483cb9d3
...
...
@@ -975,7 +975,7 @@ def _populate_grad_dict(var_to_app_to_idx,
msg
+=
"
%
s."
msg
%
(
str
(
node
.
op
),
str
(
term
),
str
(
type
(
term
)),
i
,
str
(
theano
.
get_constant_value
(
term
)))
i
,
str
(
theano
.
get_
scalar_
constant_value
(
term
)))
raise
ValueError
(
msg
)
...
...
@@ -1616,9 +1616,9 @@ def _is_zero(x):
no_constant_value
=
True
try
:
constant_value
=
theano
.
get_constant_value
(
x
)
constant_value
=
theano
.
get_
scalar_
constant_value
(
x
)
no_constant_value
=
False
except
Type
Error
:
except
theano
.
tensor
.
basic
.
NotScalarConstant
Error
:
pass
if
no_constant_value
:
...
...
theano/sandbox/cuda/basic_ops.py
浏览文件 @
483cb9d3
...
...
@@ -2691,8 +2691,8 @@ class GpuAlloc(GpuOp):
raise
TypeError
(
'Shape arguments must be integers'
,
s
)
# if s is constant 1, then we're broadcastable in that dim
try
:
const_shp
=
tensor
.
get_constant_value
(
s
)
except
Type
Error
:
const_shp
=
tensor
.
get_
scalar_
constant_value
(
s
)
except
tensor
.
NotScalarConstant
Error
:
const_shp
=
None
bcast
.
append
(
numpy
.
all
(
1
==
const_shp
))
otype
=
CudaNdarrayType
(
dtype
=
'float32'
,
broadcastable
=
bcast
)
...
...
theano/sandbox/linalg/ops.py
浏览文件 @
483cb9d3
...
...
@@ -214,8 +214,8 @@ def is_positive(v):
logger
.
debug
(
'is_positive:
%
s'
%
str
(
v
))
if
v
.
owner
and
v
.
owner
.
op
==
tensor
.
pow
:
try
:
exponent
=
tensor
.
get_constant_value
(
v
.
owner
.
inputs
[
1
])
except
Type
Error
:
exponent
=
tensor
.
get_
scalar_
constant_value
(
v
.
owner
.
inputs
[
1
])
except
tensor
.
basic
.
NotScalarConstant
Error
:
return
False
if
0
==
exponent
%
2
:
return
True
...
...
theano/sandbox/scan.py
浏览文件 @
483cb9d3
...
...
@@ -135,8 +135,8 @@ def scan(fn,
n_fixed_steps
=
int
(
n_steps
)
else
:
try
:
n_fixed_steps
=
opt
.
get_constant_value
(
n_steps
)
except
(
TypeError
,
AttributeError
)
:
n_fixed_steps
=
opt
.
get_
scalar_
constant_value
(
n_steps
)
except
tensor
.
basic
.
NotScalarConstantError
:
n_fixed_steps
=
None
# Check n_steps is an int
...
...
theano/sandbox/scan_module/scan.py
浏览文件 @
483cb9d3
...
...
@@ -335,7 +335,7 @@ def scan(fn,
T_value
=
int
(
n_steps
)
else
:
try
:
T_value
=
opt
.
get_constant_value
(
n_steps
)
T_value
=
opt
.
get_
scalar_
constant_value
(
n_steps
)
except
(
TypeError
,
AttributeError
):
T_value
=
None
...
...
theano/sandbox/scan_module/scan_utils.py
浏览文件 @
483cb9d3
...
...
@@ -24,7 +24,7 @@ from theano.compile.pfunc import rebuild_collect_shared
from
theano
import
gof
from
theano
import
tensor
,
scalar
from
theano.gof.python25
import
all
from
theano.tensor.basic
import
get_constant_value
from
theano.tensor.basic
import
get_
scalar_
constant_value
# Logging function for sending warning or info
...
...
theano/scan_module/scan.py
浏览文件 @
483cb9d3
...
...
@@ -363,8 +363,8 @@ def scan(fn,
n_fixed_steps
=
int
(
n_steps
)
else
:
try
:
n_fixed_steps
=
opt
.
get_constant_value
(
n_steps
)
except
(
TypeError
,
AttributeError
)
:
n_fixed_steps
=
opt
.
get_
scalar_
constant_value
(
n_steps
)
except
tensor
.
basic
.
NotScalarConstantError
:
n_fixed_steps
=
None
# Check n_steps is an int
...
...
theano/scan_module/scan_opt.py
浏览文件 @
483cb9d3
...
...
@@ -18,7 +18,7 @@ import numpy
import
theano
from
theano
import
tensor
from
theano.tensor
import
opt
,
get_constant_value
from
theano.tensor
import
opt
,
get_
scalar_
constant_value
from
theano
import
gof
from
theano.gof.python25
import
maxsize
,
any
from
theano.gof.opt
import
Optimizer
...
...
@@ -1164,14 +1164,14 @@ class ScanMerge(gof.Optimizer):
nsteps
=
node
.
inputs
[
0
]
try
:
nsteps
=
int
(
get_constant_value
(
nsteps
))
except
Type
Error
:
nsteps
=
int
(
get_
scalar_
constant_value
(
nsteps
))
except
tensor
.
NotScalarConstant
Error
:
pass
rep_nsteps
=
rep
.
inputs
[
0
]
try
:
rep_nsteps
=
int
(
get_constant_value
(
rep_nsteps
))
except
Type
Error
:
rep_nsteps
=
int
(
get_
scalar_
constant_value
(
rep_nsteps
))
except
tensor
.
NotScalarConstant
Error
:
pass
# Check to see if it is an input of a different node
...
...
theano/scan_module/scan_utils.py
浏览文件 @
483cb9d3
...
...
@@ -25,7 +25,7 @@ from theano.compile.pfunc import rebuild_collect_shared
from
theano
import
gof
from
theano
import
tensor
,
scalar
from
theano.gof.python25
import
all
,
OrderedDict
from
theano.tensor.basic
import
get_constant_value
from
theano.tensor.basic
import
get_
scalar_
constant_value
################ Utility Functions and Classes #######################
...
...
@@ -308,7 +308,7 @@ def isNaN_or_Inf_or_None(x):
isStr
=
False
if
not
isNaN
and
not
isInf
:
try
:
val
=
get_constant_value
(
x
)
val
=
get_
scalar_
constant_value
(
x
)
isInf
=
numpy
.
isinf
(
val
)
isNaN
=
numpy
.
isnan
(
val
)
except
Exception
:
...
...
theano/tensor/basic.py
浏览文件 @
483cb9d3
...
...
@@ -463,70 +463,88 @@ def _allclose(a, b, rtol=None, atol=None):
return
numpy
.
allclose
(
a
,
b
,
atol
=
atol_
,
rtol
=
rtol_
)
class
Not
ConstantError
(
TypeError
):
class
Not
ScalarConstantError
(
Exception
):
"""
Raised by get_constant_value if called on something that is
not constant.
For now it is a TypeError, to maintain the old interface
that get_constant_value should raise a TypeError in this
situation. However, this is unsafe because get_constant_value
could inadvertently raise a TypeError if it has a bug.
So we should eventually make NotConstantError derive
from Exception directly, and modify all code that uses
get_constant_value to catch this more specific exception.
Raised by get_scalar_constant_value if called on something that is
not a scalar constant.
"""
class
EmptyConstantError
(
NotScalarConstantError
):
"""
Raised by get_scalar_const_value if called on something that is a
zero dimensional constant.
"""
pass
def
get_constant_value
(
v
):
def
get_
scalar_
constant_value
(
v
):
"""return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
this function digs through them.
If `v` is not some view of constant
data, then raise a Not
ConstantError.
If `v` is not some view of constant
scalar data, then raise a NotScalar
ConstantError.
:note: There may be another function similar to this one in the
code, but I'm not sure where it is.
"""
if
isinstance
(
v
,
Constant
):
if
getattr
(
v
.
tag
,
'unique_value'
,
None
)
is
not
None
:
data
=
v
.
tag
.
unique_value
else
:
data
=
v
.
data
if
v
is
None
:
# None is not a scalar (and many uses of this function seem to depend
# on passing it None)
raise
NotScalarConstantError
()
if
isinstance
(
v
,
(
int
,
float
)):
return
numpy
.
asarray
(
v
)
def
numpy_scalar
(
n
):
""" Return a scalar stored in a numpy ndarray, or raise
NotScalarConstantError if the numpy ndarray is not a scalar
"""
# handle case where data is numpy.array([])
if
hasattr
(
data
,
'shape'
)
and
len
(
data
.
shape
)
==
0
or
\
__builtins__
[
'max'
](
data
.
shape
)
==
0
:
if
data
.
ndim
>
0
and
(
len
(
data
.
shape
)
==
0
or
__builtins__
[
'max'
](
data
.
shape
)
==
0
)
:
assert
numpy
.
all
(
numpy
.
array
([])
==
data
)
r
eturn
data
r
aise
EmptyConstantError
()
try
:
numpy
.
complex
(
data
)
# works for all numeric scalars
return
data
except
Exception
:
raise
NotConstantError
(
raise
Not
Scalar
ConstantError
(
'v.data is non-numeric, non-scalar, or has more than one'
' unique value'
,
v
)
' unique value'
,
n
)
if
isinstance
(
v
,
numpy
.
ndarray
):
return
numpy_scalar
(
v
)
if
isinstance
(
v
,
Constant
):
if
getattr
(
v
.
tag
,
'unique_value'
,
None
)
is
not
None
:
data
=
v
.
tag
.
unique_value
else
:
data
=
v
.
data
return
numpy_scalar
(
data
)
if
v
.
owner
:
if
isinstance
(
v
.
owner
.
op
,
Alloc
):
return
get_constant_value
(
v
.
owner
.
inputs
[
0
])
return
get_
scalar_
constant_value
(
v
.
owner
.
inputs
[
0
])
if
isinstance
(
v
.
owner
.
op
,
DimShuffle
):
return
get_constant_value
(
v
.
owner
.
inputs
[
0
])
return
get_
scalar_
constant_value
(
v
.
owner
.
inputs
[
0
])
if
isinstance
(
v
.
owner
.
op
,
Rebroadcast
):
return
get_constant_value
(
v
.
owner
.
inputs
[
0
])
return
get_
scalar_
constant_value
(
v
.
owner
.
inputs
[
0
])
if
isinstance
(
v
.
owner
.
op
,
Elemwise
)
and
\
isinstance
(
v
.
owner
.
op
.
scalar_op
,
scal
.
Second
):
shape
,
val
=
v
.
owner
.
inputs
return
get_constant_value
(
val
)
return
get_
scalar_
constant_value
(
val
)
if
isinstance
(
v
.
owner
.
op
,
scal
.
Second
):
x
,
y
=
v
.
owner
.
inputs
return
get_constant_value
(
y
)
return
get_
scalar_
constant_value
(
y
)
# Don't act as the constant_folding optimization here as this
# fct is used too early in the optimization phase. This would
# mess with the stabilization optimization.
if
(
isinstance
(
v
.
owner
.
op
,
Elemwise
)
and
isinstance
(
v
.
owner
.
op
.
scalar_op
,
scal
.
Cast
))
or
\
isinstance
(
v
.
owner
.
op
,
scal
.
Cast
):
const
=
get_constant_value
(
v
.
owner
.
inputs
[
0
])
const
=
get_
scalar_
constant_value
(
v
.
owner
.
inputs
[
0
])
ret
=
[[
None
]]
v
.
owner
.
op
.
perform
(
v
.
owner
,
[
const
],
ret
)
return
ret
[
0
][
0
]
...
...
@@ -563,7 +581,7 @@ def get_constant_value(v):
# axis.
ret
=
v
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
v
.
owner
.
op
.
idx_list
[
0
]
+
1
]
ret
=
get_constant_value
(
ret
)
ret
=
get_
scalar_
constant_value
(
ret
)
# join can cast implicitly its input in some case.
return
theano
.
_asarray
(
ret
,
dtype
=
v
.
type
.
dtype
)
if
(
v
.
owner
.
inputs
[
0
]
.
owner
and
...
...
@@ -576,7 +594,7 @@ def get_constant_value(v):
len
(
v
.
owner
.
op
.
idx_list
)
==
1
):
ret
=
v
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
v
.
owner
.
op
.
idx_list
[
0
]]
ret
=
get_constant_value
(
ret
)
ret
=
get_
scalar_
constant_value
(
ret
)
# MakeVector can cast implicitly its input in some case.
return
theano
.
_asarray
(
ret
,
dtype
=
v
.
type
.
dtype
)
...
...
@@ -589,7 +607,7 @@ def get_constant_value(v):
v
.
owner
.
op
.
idx_list
[
0
]]:
return
numpy
.
asarray
(
1
)
raise
NotConstantError
(
v
)
raise
Not
Scalar
ConstantError
(
v
)
class
TensorType
(
Type
):
...
...
@@ -1832,8 +1850,8 @@ class _tensor_py_operators:
# TO TRUMP NUMPY OPERATORS
__array_priority__
=
1000
def
get_constant_value
(
self
):
return
get_constant_value
(
self
)
def
get_
scalar_
constant_value
(
self
):
return
get_
scalar_
constant_value
(
self
)
def
zeros_like
(
model
):
return
zeros_like
(
model
)
...
...
@@ -2361,10 +2379,10 @@ class SpecifyShape(Op):
new_shape
=
[]
for
dim
in
xrange
(
node
.
inputs
[
0
]
.
ndim
):
try
:
s
=
get_constant_value
(
node
.
inputs
[
1
][
dim
])
s
=
get_
scalar_
constant_value
(
node
.
inputs
[
1
][
dim
])
s
=
as_tensor_variable
(
s
)
new_shape
.
append
(
s
)
except
Type
Error
:
except
NotScalarConstant
Error
:
new_shape
.
append
(
node
.
inputs
[
1
][
dim
])
assert
len
(
new_shape
)
==
len
(
xshape
)
...
...
@@ -2656,14 +2674,22 @@ def max(x, axis=None, keepdims=False):
:note: we return an error as numpy when we reduce a dim with a shape of 0
"""
if
isinstance
(
axis
,
(
list
,
tuple
))
and
len
(
axis
)
>
1
:
# We have a choice of implementing this call with the
# CAReduce op or the MaxAndArgmax op.
# MaxAndArgmax supports grad and Rop, so we prefer to use that.
# CAReduce is faster, but optimizations will replace MaxAndArgmax[0]
# with CAReduce at compile time, so at this stage the important
# thing is supporting all user interface features, not speed.
# Some cases can be implemented only with CAReduce.
# We thus prefer to use MaxAndArgmax, if possible. It does not
# support all axis arguments, so we may need to fall back to CAReduce.
try
:
out
=
max_and_argmax
(
x
,
axis
)[
0
]
except
Exception
:
out
=
CAReduce
(
scal
.
maximum
,
axis
)(
x
)
else
:
try
:
const
=
get_constant_value
(
axis
)
out
=
CAReduce
(
scal
.
maximum
,
list
(
const
))(
x
)
except
Exception
:
out
=
max_and_argmax
(
x
,
axis
)[
0
]
if
keepdims
:
out
=
makeKeepDims
(
x
,
out
,
axis
)
...
...
@@ -3271,8 +3297,8 @@ class Alloc(gof.Op):
(
i
,
s_as_str
))
# if s is constant 1, then we're broadcastable in that dim
try
:
const_shp
=
get_constant_value
(
s
)
except
Type
Error
:
const_shp
=
get_
scalar_
constant_value
(
s
)
except
NotScalarConstant
Error
:
const_shp
=
None
bcast
.
append
(
numpy
.
all
(
1
==
const_shp
))
otype
=
TensorType
(
dtype
=
v
.
dtype
,
broadcastable
=
bcast
)
...
...
@@ -3811,16 +3837,16 @@ def get_idx_list(inputs, idx_list):
def
extract_constant
(
x
):
'''
This function is basically a call to tensor.get_constant_value. The
This function is basically a call to tensor.get_
scalar_
constant_value. The
main difference is the behaviour in case of failure. While
get_constant_value raises an TypeError, this function returns x,
get_
scalar_
constant_value raises an TypeError, this function returns x,
as a tensor if possible. If x is a ScalarVariable from a
scalar_from_tensor, we remove the conversion. If x is just a
ScalarVariable, we convert it to a tensor with tensor_from_scalar.
'''
try
:
x
=
get_constant_value
(
x
)
except
Exception
:
x
=
get_
scalar_
constant_value
(
x
)
except
NotScalarConstantError
:
pass
if
(
isinstance
(
x
,
scal
.
ScalarVariable
)
or
isinstance
(
x
,
scal
.
sharedvar
.
ScalarSharedVariable
)):
...
...
@@ -5419,11 +5445,11 @@ class Join(Op):
# Axis can also be a constant
if
not
isinstance
(
axis
,
int
):
try
:
# Note : `get_constant_value` returns a ndarray not a
# Note : `get_
scalar_
constant_value` returns a ndarray not a
# int
axis
=
int
(
get_constant_value
(
axis
))
axis
=
int
(
get_
scalar_
constant_value
(
axis
))
except
Type
Error
:
except
NotScalarConstant
Error
:
pass
if
isinstance
(
axis
,
int
):
# Basically, broadcastable -> length 1, but the
...
...
@@ -5790,9 +5816,9 @@ class Reshape(Op):
# Try to see if we can infer that y has a constant value of 1.
# If so, that dimension should be broadcastable.
try
:
bcasts
[
index
]
=
(
hasattr
(
y
,
'get_constant_value'
)
and
y
.
get_constant_value
()
==
1
)
except
Type
Error
:
bcasts
[
index
]
=
(
hasattr
(
y
,
'get_
scalar_
constant_value'
)
and
y
.
get_
scalar_
constant_value
()
==
1
)
except
NotScalarConstant
Error
:
pass
return
gof
.
Apply
(
self
,
[
x
,
shp
],
[
tensor
(
x
.
type
.
dtype
,
bcasts
)])
...
...
@@ -5865,10 +5891,10 @@ class Reshape(Op):
for
i
in
xrange
(
self
.
ndim
):
default_os_i
=
theano
.
tensor
.
opt
.
Shape_i
(
i
)(
node
.
outputs
[
0
])
try
:
os_i
=
get_constant_value
(
node
.
inputs
[
1
][
i
])
.
item
()
os_i
=
get_
scalar_
constant_value
(
node
.
inputs
[
1
][
i
])
.
item
()
if
os_i
==
-
1
:
os_i
=
default_os_i
except
Type
Error
:
except
NotScalarConstant
Error
:
os_i
=
default_os_i
oshape
.
append
(
os_i
)
return
[
tuple
(
oshape
)]
...
...
@@ -6148,9 +6174,9 @@ class ARange(Op):
def
is_constant_value
(
var
,
value
):
try
:
v
=
get_constant_value
(
var
)
v
=
get_
scalar_
constant_value
(
var
)
return
numpy
.
all
(
v
==
value
)
except
Exception
:
except
NotScalarConstantError
:
pass
return
False
...
...
theano/tensor/blas.py
浏览文件 @
483cb9d3
...
...
@@ -1614,7 +1614,7 @@ def local_gemm_to_ger(node):
xv
=
x
.
dimshuffle
(
0
)
yv
=
y
.
dimshuffle
(
1
)
try
:
bval
=
T
.
get_constant_value
(
b
)
bval
=
T
.
get_
scalar_
constant_value
(
b
)
except
TypeError
:
# b isn't a constant, GEMM is doing useful pre-scaling
return
...
...
theano/tensor/extra_ops.py
浏览文件 @
483cb9d3
...
...
@@ -262,15 +262,15 @@ class RepeatOp(theano.Op):
broadcastable
=
[
False
]
else
:
try
:
const_reps
=
basic
.
get_constant_value
(
repeats
)
except
basic
.
NotConstantError
:
const_reps
=
basic
.
get_
scalar_
constant_value
(
repeats
)
except
basic
.
Not
Scalar
ConstantError
:
const_reps
=
None
if
const_reps
==
1
:
broadcastable
=
x
.
broadcastable
else
:
broadcastable
=
list
(
x
.
broadcastable
)
broadcastable
[
self
.
axis
]
=
False
out_type
=
theano
.
tensor
.
TensorType
(
x
.
dtype
,
broadcastable
)
return
theano
.
Apply
(
self
,
[
x
,
repeats
],
[
out_type
()])
...
...
theano/tensor/nnet/conv.py
浏览文件 @
483cb9d3
...
...
@@ -15,7 +15,7 @@ import logging
import
numpy
import
theano
from
theano.tensor
import
(
as_tensor_variable
,
blas
,
get_constant_value
,
from
theano.tensor
import
(
as_tensor_variable
,
blas
,
get_
scalar_
constant_value
,
patternbroadcast
)
from
theano
import
OpenMPOp
,
config
from
theano.gof
import
Apply
...
...
@@ -90,7 +90,7 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
image_shape
=
list
(
image_shape
)
for
i
in
xrange
(
len
(
image_shape
)):
if
image_shape
[
i
]
is
not
None
:
image_shape
[
i
]
=
get_constant_value
(
image_shape
[
i
]
=
get_
scalar_
constant_value
(
as_tensor_variable
(
image_shape
[
i
]))
assert
str
(
image_shape
[
i
]
.
dtype
)
.
startswith
(
'int'
)
image_shape
[
i
]
=
int
(
image_shape
[
i
])
...
...
@@ -98,7 +98,7 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
filter_shape
=
list
(
filter_shape
)
for
i
in
xrange
(
len
(
filter_shape
)):
if
filter_shape
[
i
]
is
not
None
:
filter_shape
[
i
]
=
get_constant_value
(
filter_shape
[
i
]
=
get_
scalar_
constant_value
(
as_tensor_variable
(
filter_shape
[
i
]))
assert
str
(
filter_shape
[
i
]
.
dtype
)
.
startswith
(
'int'
)
filter_shape
[
i
]
=
int
(
filter_shape
[
i
])
...
...
theano/tensor/nnet/nnet.py
浏览文件 @
483cb9d3
...
...
@@ -1409,8 +1409,8 @@ def _check_rows_is_arange_len_labels(rows, labels):
def
_is_const
(
z
,
val
,
approx
=
False
):
try
:
maybe
=
opt
.
get_constant_value
(
z
)
except
Type
Error
:
maybe
=
opt
.
get_
scalar_
constant_value
(
z
)
except
tensor
.
NotScalarConstant
Error
:
return
False
if
approx
:
return
numpy
.
allclose
(
maybe
,
val
)
...
...
theano/tensor/nnet/sigm.py
浏览文件 @
483cb9d3
...
...
@@ -14,7 +14,7 @@ from theano.compile import optdb
from
theano.configparser
import
AddConfigVar
,
BoolParam
from
theano.printing
import
pprint
,
debugprint
from
theano.tensor
import
basic
as
tensor
from
theano.tensor
import
elemwise
,
opt
from
theano.tensor
import
elemwise
,
opt
,
NotScalarConstantError
############
...
...
@@ -136,9 +136,9 @@ def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1
"""
try
:
v
=
opt
.
get_constant_value
(
expr
)
v
=
opt
.
get_
scalar_
constant_value
(
expr
)
return
numpy
.
allclose
(
v
,
1
)
except
Type
Error
:
except
tensor
.
NotScalarConstant
Error
:
return
False
log1msigm_to_softplus
=
gof
.
PatternSub
(
...
...
@@ -275,9 +275,9 @@ def is_neg(var):
if
apply
.
op
==
tensor
.
mul
and
len
(
apply
.
inputs
)
>=
2
:
for
idx
,
mul_input
in
enumerate
(
apply
.
inputs
):
try
:
constant
=
opt
.
get_constant_value
(
mul_input
)
constant
=
opt
.
get_
scalar_
constant_value
(
mul_input
)
is_minus_1
=
numpy
.
allclose
(
constant
,
-
1
)
except
Type
Error
:
except
NotScalarConstant
Error
:
is_minus_1
=
False
if
is_minus_1
:
# Found a multiplication by -1.
...
...
@@ -647,7 +647,7 @@ def local_1msigmoid(node):
return
# graph is using both sigm and 1-sigm
if
sub_r
.
owner
and
sub_r
.
owner
.
op
==
sigmoid
:
try
:
val_l
=
opt
.
get_constant_value
(
sub_l
)
val_l
=
opt
.
get_
scalar_
constant_value
(
sub_l
)
except
Exception
,
e
:
return
if
numpy
.
allclose
(
numpy
.
sum
(
val_l
),
1
):
...
...
theano/tensor/nnet/tests/test_conv.py
浏览文件 @
483cb9d3
...
...
@@ -30,10 +30,10 @@ class TestConv2D(utt.InferShapeTester):
verify_grad
=
True
,
should_raise
=
False
):
if
N_image_shape
is
None
:
N_image_shape
=
[
T
.
get_constant_value
(
T
.
N_image_shape
=
[
T
.
get_
scalar_
constant_value
(
T
.
as_tensor_variable
(
x
))
for
x
in
image_shape
]
if
N_filter_shape
is
None
:
N_filter_shape
=
[
T
.
get_constant_value
(
T
.
N_filter_shape
=
[
T
.
get_
scalar_
constant_value
(
T
.
as_tensor_variable
(
x
))
for
x
in
filter_shape
]
if
input
is
None
:
...
...
theano/tensor/opt.py
浏览文件 @
483cb9d3
...
...
@@ -33,7 +33,7 @@ from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer
)
from
theano.gof.opt
import
merge_optimizer
from
theano.gof
import
toolbox
,
DestroyHandler
from
basic
import
get_
constant_value
,
Shape
Error
from
basic
import
get_
scalar_constant_value
,
ShapeError
,
NotScalarConstant
Error
theano
.
configparser
.
AddConfigVar
(
'on_shape_error'
,
...
...
@@ -92,10 +92,10 @@ def scalarconsts_rest(inputs):
nonconsts
=
[]
for
i
in
inputs
:
try
:
v
=
get_constant_value
(
i
)
v
=
get_
scalar_
constant_value
(
i
)
consts
.
append
(
v
)
origconsts
.
append
(
i
)
except
Exception
:
except
NotScalarConstantError
:
nonconsts
.
append
(
i
)
return
consts
,
origconsts
,
nonconsts
...
...
@@ -125,7 +125,13 @@ def broadcast_like(value, template, fgraph, dtype=None):
if
rval
.
broadcastable
[
i
]
and
not
template
.
broadcastable
[
i
]])
assert
rval
.
type
.
dtype
==
dtype
assert
rval
.
type
.
broadcastable
==
template
.
broadcastable
if
rval
.
type
.
broadcastable
!=
template
.
broadcastable
:
raise
AssertionError
(
"rval.type.broadcastable is "
+
str
(
rval
.
type
.
broadcastable
)
+
" but template.broadcastable is"
+
str
(
template
.
broadcastable
))
return
rval
...
...
@@ -322,15 +328,15 @@ def local_0_dot_x(node):
y
=
node
.
inputs
[
1
]
replace
=
False
try
:
if
get_constant_value
(
x
)
==
0
:
if
get_
scalar_
constant_value
(
x
)
==
0
:
replace
=
True
except
Type
Error
:
except
NotScalarConstant
Error
:
pass
try
:
if
get_constant_value
(
y
)
==
0
:
if
get_
scalar_
constant_value
(
y
)
==
0
:
replace
=
True
except
Type
Error
:
except
NotScalarConstant
Error
:
pass
if
replace
:
...
...
@@ -1177,9 +1183,9 @@ def local_subtensor_make_vector(node):
elif
isinstance
(
idx
,
Variable
):
# if it is a constant we can do something with it
try
:
v
=
get_constant_value
(
idx
)
v
=
get_
scalar_
constant_value
(
idx
)
return
[
x
.
owner
.
inputs
[
v
]]
except
Exception
:
except
NotScalarConstantError
:
pass
else
:
# it is a slice of ints and/or Variables
...
...
@@ -1315,13 +1321,13 @@ def local_remove_useless_assert(node):
cond
=
[]
for
c
in
node
.
inputs
[
1
:]:
try
:
const
=
get_constant_value
(
c
)
const
=
get_
scalar_
constant_value
(
c
)
if
0
!=
const
.
ndim
or
const
==
0
:
#Should we raise an error here? How to be sure it
#is not catched?
cond
.
append
(
c
)
except
Type
Error
:
except
NotScalarConstant
Error
:
cond
.
append
(
c
)
if
len
(
cond
)
==
0
:
...
...
@@ -1477,7 +1483,7 @@ def local_upcast_elemwise_constant_inputs(node):
else
:
try
:
# works only for scalars
cval_i
=
get_constant_value
(
i
)
cval_i
=
get_
scalar_
constant_value
(
i
)
if
all
(
i
.
broadcastable
):
new_inputs
.
append
(
T
.
shape_padleft
(
T
.
cast
(
cval_i
,
output_dtype
),
...
...
@@ -1490,7 +1496,7 @@ def local_upcast_elemwise_constant_inputs(node):
*
[
shape_i
(
d
)(
i
)
for
d
in
xrange
(
i
.
ndim
)]))
#print >> sys.stderr, "AAA",
#*[Shape_i(d)(i) for d in xrange(i.ndim)]
except
Type
Error
:
except
NotScalarConstant
Error
:
#for the case of a non-scalar
if
isinstance
(
i
,
T
.
TensorConstant
):
new_inputs
.
append
(
T
.
cast
(
i
,
output_dtype
))
...
...
@@ -1550,8 +1556,8 @@ def local_useless_subtensor(node):
length_pos
=
shape_of
[
node
.
inputs
[
0
]][
pos
]
try
:
length_pos_data
=
get_constant_value
(
length_pos
)
except
Type
Error
:
length_pos_data
=
get_
scalar_
constant_value
(
length_pos
)
except
NotScalarConstant
Error
:
pass
if
isinstance
(
idx
.
stop
,
int
):
...
...
@@ -2032,9 +2038,9 @@ def local_incsubtensor_of_allocs(node):
y
=
node
.
inputs
[
1
]
replace
=
False
try
:
if
get_constant_value
(
y
)
==
0
:
if
get_
scalar_
constant_value
(
y
)
==
0
:
replace
=
True
except
Type
Error
:
except
NotScalarConstant
Error
:
pass
if
replace
:
...
...
@@ -2059,13 +2065,13 @@ def local_setsubtensor_of_allocs(node):
replace_y
=
None
try
:
replace_x
=
get_constant_value
(
x
)
except
Type
Error
:
replace_x
=
get_
scalar_
constant_value
(
x
)
except
NotScalarConstant
Error
:
pass
try
:
replace_y
=
get_constant_value
(
y
)
except
Type
Error
:
replace_y
=
get_
scalar_
constant_value
(
y
)
except
NotScalarConstant
Error
:
pass
if
(
replace_x
==
replace_y
and
...
...
@@ -2253,24 +2259,24 @@ def local_mul_switch_sink(node):
if
i
.
owner
and
i
.
owner
.
op
==
T
.
switch
:
switch
=
i
.
owner
try
:
if
get_constant_value
(
switch
.
inputs
[
1
])
==
0.
:
if
get_
scalar_
constant_value
(
switch
.
inputs
[
1
])
==
0.
:
listmul
=
node
.
inputs
[:
idx
]
+
node
.
inputs
[
idx
+
1
:]
fct
=
[
T
.
switch
(
switch
.
inputs
[
0
],
0
,
T
.
mul
(
*
(
listmul
+
[
switch
.
inputs
[
2
]])))]
fct
[
0
]
.
values_eq_approx
=
fct
[
0
]
.
type
.
values_eq_approx_remove_nan
return
fct
except
Type
Error
:
except
NotScalarConstant
Error
:
pass
try
:
if
get_constant_value
(
switch
.
inputs
[
2
])
==
0.
:
if
get_
scalar_
constant_value
(
switch
.
inputs
[
2
])
==
0.
:
listmul
=
node
.
inputs
[:
idx
]
+
node
.
inputs
[
idx
+
1
:]
fct
=
[
T
.
switch
(
switch
.
inputs
[
0
],
T
.
mul
(
*
(
listmul
+
[
switch
.
inputs
[
1
]])),
0
)]
fct
[
0
]
.
values_eq_approx
=
fct
[
0
]
.
type
.
values_eq_approx_remove_nan
return
fct
except
Type
Error
:
except
NotScalarConstant
Error
:
pass
return
False
...
...
@@ -2295,22 +2301,22 @@ def local_div_switch_sink(node):
if
node
.
inputs
[
0
]
.
owner
and
node
.
inputs
[
0
]
.
owner
.
op
==
T
.
switch
:
switch
=
node
.
inputs
[
0
]
.
owner
try
:
if
get_constant_value
(
switch
.
inputs
[
1
])
==
0.
:
if
get_
scalar_
constant_value
(
switch
.
inputs
[
1
])
==
0.
:
fct
=
[
T
.
switch
(
switch
.
inputs
[
0
],
0
,
op
(
switch
.
inputs
[
2
],
node
.
inputs
[
1
]))]
fct
[
0
]
.
values_eq_approx
=
fct
[
0
]
.
type
.
values_eq_approx_remove_nan
return
fct
except
Type
Error
:
except
NotScalarConstant
Error
:
pass
try
:
if
get_constant_value
(
switch
.
inputs
[
2
])
==
0.
:
if
get_
scalar_
constant_value
(
switch
.
inputs
[
2
])
==
0.
:
fct
=
[
T
.
switch
(
switch
.
inputs
[
0
],
op
(
switch
.
inputs
[
1
],
node
.
inputs
[
1
]),
0
)]
fct
[
0
]
.
values_eq_approx
=
fct
[
0
]
.
type
.
values_eq_approx_remove_nan
return
fct
except
Type
Error
:
except
NotScalarConstant
Error
:
pass
return
False
...
...
@@ -2375,7 +2381,7 @@ if 0:
def
tmp
(
thing
):
try
:
return
T
.
get_constant_value
(
thing
)
return
T
.
get_
scalar_
constant_value
(
thing
)
except
(
TypeError
,
ValueError
),
e
:
print
e
,
thing
.
owner
.
inputs
[
0
]
return
None
...
...
@@ -2702,8 +2708,8 @@ class Canonizer(gof.LocalOptimizer):
"""
if
isinstance
(
v
,
Variable
):
try
:
return
get_constant_value
(
v
)
except
Type
Error
:
return
get_
scalar_
constant_value
(
v
)
except
NotScalarConstant
Error
:
return
None
else
:
return
v
...
...
@@ -3204,15 +3210,15 @@ def local_sum_alloc(node):
if
(
node
.
op
.
axis
is
None
or
node
.
op
.
axis
==
tuple
(
range
(
input
.
ndim
))):
try
:
val
=
get_constant_value
(
input
)
val
=
get_
scalar_
constant_value
(
input
)
assert
val
.
size
==
1
val
=
val
.
reshape
(
1
)[
0
]
*
T
.
mul
(
*
shapes
)
return
[
T
.
cast
(
val
,
dtype
=
node
.
outputs
[
0
]
.
dtype
)]
except
Type
Error
:
except
NotScalarConstant
Error
:
pass
else
:
try
:
val
=
get_constant_value
(
input
)
val
=
get_
scalar_
constant_value
(
input
)
assert
val
.
size
==
1
val
=
val
.
reshape
(
1
)[
0
]
to_prod
=
[
shapes
[
i
]
for
i
in
xrange
(
len
(
shapes
))
...
...
@@ -3222,7 +3228,7 @@ def local_sum_alloc(node):
return
[
T
.
alloc
(
T
.
cast
(
val
,
dtype
=
node
.
outputs
[
0
]
.
dtype
),
*
[
shapes
[
i
]
for
i
in
xrange
(
len
(
shapes
))
if
i
not
in
node
.
op
.
axis
])]
except
Type
Error
:
except
NotScalarConstant
Error
:
pass
...
...
@@ -3282,8 +3288,8 @@ def local_mul_zero(node):
for
i
in
node
.
inputs
:
try
:
value
=
get_constant_value
(
i
)
except
Type
Error
:
value
=
get_
scalar_
constant_value
(
i
)
except
NotScalarConstant
Error
:
continue
#print 'MUL by value', value, node.inputs
if
N
.
all
(
value
==
0
):
...
...
@@ -3520,8 +3526,8 @@ def local_add_specialize(node):
new_inputs
=
[]
for
input
in
node
.
inputs
:
try
:
y
=
get_constant_value
(
input
)
except
Type
Error
:
y
=
get_
scalar_
constant_value
(
input
)
except
NotScalarConstant
Error
:
y
=
input
if
numpy
.
all
(
y
==
0.0
):
continue
...
...
@@ -3614,7 +3620,7 @@ def local_abs_merge(node):
if
i
.
owner
and
i
.
owner
.
op
==
T
.
abs_
:
inputs
.
append
(
i
.
owner
.
inputs
[
0
])
else
:
const
=
get_constant_value
(
i
)
const
=
get_
scalar_
constant_value
(
i
)
if
not
(
const
>=
0
)
.
all
():
return
False
inputs
.
append
(
i
)
...
...
@@ -3880,9 +3886,9 @@ def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1
"""
try
:
v
=
get_constant_value
(
expr
)
v
=
get_
scalar_
constant_value
(
expr
)
return
numpy
.
allclose
(
v
,
1
)
except
Type
Error
:
except
NotScalarConstant
Error
:
return
False
...
...
@@ -3890,9 +3896,9 @@ def _is_minus1(expr):
"""rtype bool. True iff expr is a constant close to -1
"""
try
:
v
=
get_constant_value
(
expr
)
v
=
get_
scalar_
constant_value
(
expr
)
return
numpy
.
allclose
(
v
,
-
1
)
except
Type
Error
:
except
NotScalarConstant
Error
:
return
False
#1+erf(x)=>erfc(-x)
...
...
@@ -4132,8 +4138,8 @@ def local_grad_log_erfc_neg(node):
mul_neg
=
T
.
mul
(
*
mul_inputs
)
try
:
cst2
=
get_constant_value
(
mul_neg
.
owner
.
inputs
[
0
])
except
Type
Error
:
cst2
=
get_
scalar_
constant_value
(
mul_neg
.
owner
.
inputs
[
0
])
except
NotScalarConstant
Error
:
return
False
if
len
(
mul_neg
.
owner
.
inputs
)
==
2
:
...
...
@@ -4159,8 +4165,8 @@ def local_grad_log_erfc_neg(node):
x
=
erfc_x
try
:
cst
=
get_constant_value
(
erfc_x
.
owner
.
inputs
[
0
])
except
Type
Error
:
cst
=
get_
scalar_
constant_value
(
erfc_x
.
owner
.
inputs
[
0
])
except
NotScalarConstant
Error
:
return
False
if
cst2
!=
-
cst
*
2
:
return
False
...
...
theano/tensor/opt_uncanonicalize.py
浏览文件 @
483cb9d3
...
...
@@ -41,7 +41,7 @@ from theano.gof.python25 import any, all
from
theano.gof.opt
import
Optimizer
from
theano.gof
import
InconsistencyError
,
toolbox
from
basic
import
get_
constant_value
from
basic
import
get_
scalar_constant_value
,
NotScalarConstantError
from
theano.tensor.opt
import
register_uncanonicalize
from
theano
import
scalar
as
scal
...
...
@@ -64,8 +64,8 @@ class MaxAndArgmaxOptimizer(Optimizer):
if
node
.
op
==
T
.
_max_and_argmax
:
if
len
(
node
.
outputs
[
1
]
.
clients
)
==
0
:
try
:
axis
=
get_constant_value
(
node
.
inputs
[
1
])
except
(
ValueError
,
TypeError
),
e
:
axis
=
get_
scalar_
constant_value
(
node
.
inputs
[
1
])
except
NotScalarConstantError
:
return
False
new
=
CAReduce
(
scal
.
maximum
,
axis
)(
node
.
inputs
[
0
])
...
...
theano/tensor/tests/test_basic.py
浏览文件 @
483cb9d3
...
...
@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
Reshape
,
row
,
scalar
,
scalars
,
second
,
smallest
,
stack
,
sub
,
Tensor
,
tensor_copy
,
tensordot
,
tensordot_grad
,
TensorType
,
unbroadcast
,
var
,
Join
,
shape
,
MaxAndArgmax
,
lscalar
,
zvector
,
exp
,
get_constant_value
,
ivector
,
reshape
,
scalar_from_tensor
,
scal
,
get_
scalar_
constant_value
,
ivector
,
reshape
,
scalar_from_tensor
,
scal
,
iscalars
,
arange
,
dscalars
,
fvector
,
imatrix
,
numeric_grad
,
opt
,
ComplexError
,
TensorDot
,
lvector
,
true_div
,
max
,
min
,
Split
,
roll
,
tile
,
patternbroadcast
,
Eye
,
Shape
,
Default
,
Dot
,
PermuteRowElements
,
...
...
@@ -2140,7 +2140,7 @@ class T_max_and_argmax(unittest.TestCase):
cost
=
argmax
(
x
,
axis
=
0
)
.
sum
()
value_error_raised
=
False
gx
=
grad
(
cost
,
x
)
val
=
tensor
.
get_constant_value
(
gx
)
val
=
tensor
.
get_
scalar_
constant_value
(
gx
)
assert
val
==
0.0
def
test_grad
(
self
):
...
...
@@ -6167,40 +6167,40 @@ def test_dimshuffle_duplicate():
assert
success
class
T_get_constant_value
(
unittest
.
TestCase
):
def
test_get_constant_value
(
self
):
class
T_get_
scalar_
constant_value
(
unittest
.
TestCase
):
def
test_get_
scalar_
constant_value
(
self
):
a
=
tensor
.
stack
(
1
,
2
,
3
)
assert
get_constant_value
(
a
[
0
])
==
1
assert
get_constant_value
(
a
[
1
])
==
2
assert
get_constant_value
(
a
[
2
])
==
3
assert
get_
scalar_
constant_value
(
a
[
0
])
==
1
assert
get_
scalar_
constant_value
(
a
[
1
])
==
2
assert
get_
scalar_
constant_value
(
a
[
2
])
==
3
b
=
tensor
.
iscalar
()
a
=
tensor
.
stack
(
b
,
2
,
3
)
self
.
assertRaises
(
TypeError
,
get
_constant_value
,
a
[
0
])
assert
get_constant_value
(
a
[
1
])
==
2
assert
get_constant_value
(
a
[
2
])
==
3
self
.
assertRaises
(
tensor
.
basic
.
NotScalarConstantError
,
get_scalar
_constant_value
,
a
[
0
])
assert
get_
scalar_
constant_value
(
a
[
1
])
==
2
assert
get_
scalar_
constant_value
(
a
[
2
])
==
3
# For now get_constant_value goes through only MakeVector and Join of
# For now get_
scalar_
constant_value goes through only MakeVector and Join of
# scalars.
v
=
tensor
.
ivector
()
a
=
tensor
.
stack
(
v
,
2
,
3
)
self
.
assertRaises
(
TypeError
,
get
_constant_value
,
a
[
0
])
self
.
assertRaises
(
TypeError
,
get
_constant_value
,
a
[
1
])
self
.
assertRaises
(
TypeError
,
get
_constant_value
,
a
[
2
])
self
.
assertRaises
(
tensor
.
NotScalarConstantError
,
get_scalar
_constant_value
,
a
[
0
])
self
.
assertRaises
(
tensor
.
NotScalarConstantError
,
get_scalar
_constant_value
,
a
[
1
])
self
.
assertRaises
(
tensor
.
NotScalarConstantError
,
get_scalar
_constant_value
,
a
[
2
])
# Test the case SubTensor(Shape(v)) when the dimensions
# is broadcastable.
v
=
tensor
.
row
()
assert
get_constant_value
(
v
.
shape
[
0
])
==
1
assert
get_
scalar_
constant_value
(
v
.
shape
[
0
])
==
1
def
test_subtensor_of_constant
(
self
):
c
=
constant
(
rand
(
5
))
for
i
in
range
(
c
.
value
.
shape
[
0
]):
assert
get_constant_value
(
c
[
i
])
==
c
.
value
[
i
]
assert
get_
scalar_
constant_value
(
c
[
i
])
==
c
.
value
[
i
]
c
=
constant
(
rand
(
5
,
5
))
for
i
in
range
(
c
.
value
.
shape
[
0
]):
for
j
in
range
(
c
.
value
.
shape
[
1
]):
assert
get_constant_value
(
c
[
i
,
j
])
==
c
.
value
[
i
,
j
]
assert
get_
scalar_
constant_value
(
c
[
i
,
j
])
==
c
.
value
[
i
,
j
]
class
T_as_tensor_variable
(
unittest
.
TestCase
):
...
...
theano/tensor/tests/test_elemwise.py
浏览文件 @
483cb9d3
...
...
@@ -856,7 +856,7 @@ def test_gt_grad():
"""A user test that failed.
Something about it made Elemwise.grad return something that was
too complicated for get_constant_value to recognize as being 0, so
too complicated for get_
scalar_
constant_value to recognize as being 0, so
gradient.grad reported that it was not a valid gradient of an
integer.
...
...
theano/tests/test_tutorial.py
浏览文件 @
483cb9d3
...
...
@@ -936,7 +936,7 @@ class T_fibby(unittest.TestCase):
if
node
.
op
==
fibby
:
x
=
node
.
inputs
[
0
]
try
:
if
numpy
.
all
(
0
==
get_constant_value
(
x
)):
if
numpy
.
all
(
0
==
get_
scalar_
constant_value
(
x
)):
return
[
x
]
except
TypeError
:
pass
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论