Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6a2259f4
提交
6a2259f4
authored
10月 11, 2020
作者:
Brandon T. Willard
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove constant caching
Closes #99
上级
b99dea2f
隐藏空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
137 行增加
和
252 行删除
+137
-252
test_fg.py
tests/gof/test_fg.py
+4
-14
test_graph.py
tests/gof/test_graph.py
+0
-13
test_scan.py
tests/scan_module/test_scan.py
+54
-54
test_basic.py
tests/tensor/test_basic.py
+40
-67
test_printing.py
tests/test_printing.py
+36
-36
__init__.py
theano/gof/__init__.py
+0
-1
fg.py
theano/gof/fg.py
+0
-19
opt.py
theano/gof/opt.py
+1
-6
basic.py
theano/tensor/basic.py
+2
-42
没有找到文件。
tests/gof/test_fg.py
浏览文件 @
6a2259f4
...
...
@@ -5,25 +5,14 @@ import pytest
import
theano
from
theano.compat
import
PY3
from
theano.gof
import
CachedConstantError
,
FunctionGraph
from
theano.gof
.fg
import
FunctionGraph
from
theano
import
tensor
as
tt
class
TFunctionGraph
:
def
test_constant_cache_error
(
self
):
v
=
theano
.
tensor
.
constant
(
1
)
assert
v
.
cached
with
pytest
.
raises
(
CachedConstantError
):
FunctionGraph
([],
[
v
+
1
],
clone
=
False
)
def
test_clone
(
self
):
v
=
theano
.
tensor
.
constant
(
1
)
assert
v
.
cached
FunctionGraph
([],
[
v
+
1
])
class
TestFunctionGraph
:
def
test_pickle
(
self
):
v
=
tt
.
vector
()
func
=
theano
.
gof
.
FunctionGraph
([
v
],
[
v
+
1
])
func
=
FunctionGraph
([
v
],
[
v
+
1
])
s
=
pickle
.
dumps
(
func
)
pickle
.
loads
(
s
)
...
...
@@ -31,6 +20,7 @@ class TFunctionGraph:
@pytest.mark.skipif
(
not
theano
.
config
.
cxx
,
reason
=
"G++ not available, so we need to skip this test."
)
@pytest.mark.slow
def
test_node_outputs_not_used
(
self
):
# In the past, we where removing some not used variable from
# fgraph.variables event if the apply had other output used in
...
...
tests/gof/test_graph.py
浏览文件 @
6a2259f4
...
...
@@ -266,27 +266,14 @@ class TestAutoName:
assert
r2
.
auto_name
==
"auto_"
+
str
(
autoname_id
+
1
)
def
test_constant
(
self
):
# Make sure the value we will use for the test aren't yet in the cache.
r1
=
tensor
.
constant
(
1.5
)
del
tensor
.
constant_cache
[
r1
.
signature
()]
r1
=
tensor
.
constant
(
1.6
)
del
tensor
.
constant_cache
[
r1
.
signature
()]
# Get counter value
autoname_id
=
next
(
Variable
.
__count__
)
Variable
.
__count__
=
count
(
autoname_id
)
r1
=
tensor
.
constant
(
1.5
)
r2
=
tensor
.
constant
(
1.5
)
assert
r1
.
auto_name
==
"auto_"
+
str
(
autoname_id
),
(
r1
.
auto_name
,
"auto_"
+
str
(
autoname_id
),
)
# We reuse the same variable
assert
r2
.
auto_name
==
"auto_"
+
str
(
autoname_id
),
(
r2
.
auto_name
,
"auto_"
+
str
(
autoname_id
),
)
assert
r1
is
r2
r3
=
tensor
.
constant
(
1.6
)
assert
r3
.
auto_name
==
"auto_"
+
str
(
autoname_id
+
1
)
...
...
tests/scan_module/test_scan.py
浏览文件 @
6a2259f4
...
...
@@ -1673,72 +1673,72 @@ class TestScan:
| |<RandomStateType> [id DD]
| |Shape [id DE] ''
| | |Subtensor{int64::} [id DA] ''
| |TensorConstant{0.1} [id
CW
]
| |TensorConstant{0.9} [id
CX
]
|Sum{acc_dtype=float64} [id D
F
] ''
|Elemwise{mul,no_inplace} [id D
G
] ''
| |TensorConstant{0.1} [id
DF
]
| |TensorConstant{0.9} [id
DG
]
|Sum{acc_dtype=float64} [id D
H
] ''
|Elemwise{mul,no_inplace} [id D
I
] ''
|for{cpu,scan_fn}.2 [id H] ''
|RandomFunction{uniform}.1 [id D
H
] ''
|<RandomStateType> [id D
I
]
|Shape [id D
J
] ''
|RandomFunction{uniform}.1 [id D
J
] ''
|<RandomStateType> [id D
K
]
|Shape [id D
L
] ''
| |for{cpu,scan_fn}.2 [id H] ''
|TensorConstant{0.1} [id
CW
]
|TensorConstant{0.9} [id
CX
]
|TensorConstant{0.1} [id
DM
]
|TensorConstant{0.9} [id
DN
]
Inner graphs of the scan ops:
for{cpu,scan_fn}.1 [id H] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id D
K
] ''
> |y0[t-1] [id D
L
] -> [id BR]
> |y0[t-3] [id D
M
] -> [id BR]
> |InplaceDimShuffle{} [id D
N
] ''
> |CGemv{inplace} [id D
O
] ''
> |AllocEmpty{dtype='
%(float)
s'} [id D
P
] ''
> | |TensorConstant{1} [id D
Q
]
> |TensorConstant{1.0} [id D
R
]
> |InplaceDimShuffle{x,0} [id D
S
] ''
> | |wout_copy [id D
T
] -> [id CQ]
> |x0[t-1] [id D
U
] -> [id CB]
> |TensorConstant{0.0} [id D
V
]
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id
DW
] ''
> |CGemv{no_inplace} [id
DX
] ''
> | |AllocEmpty{dtype='
%(float)
s'} [id
DY
] ''
> | | |Shape_i{1} [id
DZ
] ''
> | | |win_copy [id E
A
] -> [id CR]
> | |TensorConstant{1.0} [id D
R
]
> | |InplaceDimShuffle{1,0} [id E
B
] 'win_copy.T'
> | | |win_copy [id E
A
] -> [id CR]
> | |u1[t] [id E
C
] -> [id BJ]
> | |TensorConstant{0.0} [id D
V
]
> |u2[t] [id E
D
] -> [id BN]
> |u2[t-1] [id E
E
] -> [id BL]
> |u2[t+1] [id E
F
] -> [id BP]
> |win2_copy [id E
G
] -> [id CO]
> |CGemv{inplace} [id E
H
] ''
> |AllocEmpty{dtype='
%(float)
s'} [id E
I
] ''
> | |Shape_i{1} [id E
J
] ''
> | |w_copy [id E
K
] -> [id CP]
> |TensorConstant{1.0} [id D
R
]
> |InplaceDimShuffle{1,0} [id E
L
] 'w_copy.T'
> | |w_copy [id E
K
] -> [id CP]
> |x0[t-1] [id D
U
] -> [id CB]
> |TensorConstant{0.0} [id D
V
]
>CGemv{no_inplace} [id
DX
] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id D
O
] ''
> |y0[t-1] [id D
P
] -> [id BR]
> |y0[t-3] [id D
Q
] -> [id BR]
> |InplaceDimShuffle{} [id D
R
] ''
> |CGemv{inplace} [id D
S
] ''
> |AllocEmpty{dtype='
%(float)
s'} [id D
T
] ''
> | |TensorConstant{1} [id D
U
]
> |TensorConstant{1.0} [id D
V
]
> |InplaceDimShuffle{x,0} [id D
W
] ''
> | |wout_copy [id D
X
] -> [id CQ]
> |x0[t-1] [id D
Y
] -> [id CB]
> |TensorConstant{0.0} [id D
Z
]
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id
EA
] ''
> |CGemv{no_inplace} [id
EB
] ''
> | |AllocEmpty{dtype='
%(float)
s'} [id
EC
] ''
> | | |Shape_i{1} [id
ED
] ''
> | | |win_copy [id E
E
] -> [id CR]
> | |TensorConstant{1.0} [id D
V
]
> | |InplaceDimShuffle{1,0} [id E
F
] 'win_copy.T'
> | | |win_copy [id E
E
] -> [id CR]
> | |u1[t] [id E
G
] -> [id BJ]
> | |TensorConstant{0.0} [id D
Z
]
> |u2[t] [id E
H
] -> [id BN]
> |u2[t-1] [id E
I
] -> [id BL]
> |u2[t+1] [id E
J
] -> [id BP]
> |win2_copy [id E
K
] -> [id CO]
> |CGemv{inplace} [id E
L
] ''
> |AllocEmpty{dtype='
%(float)
s'} [id E
M
] ''
> | |Shape_i{1} [id E
N
] ''
> | |w_copy [id E
O
] -> [id CP]
> |TensorConstant{1.0} [id D
V
]
> |InplaceDimShuffle{1,0} [id E
P
] 'w_copy.T'
> | |w_copy [id E
O
] -> [id CP]
> |x0[t-1] [id D
Y
] -> [id CB]
> |TensorConstant{0.0} [id D
Z
]
>CGemv{no_inplace} [id
EB
] ''
for{cpu,scan_fn}.0 [id H] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id D
K
] ''
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id
DW
] ''
>CGemv{no_inplace} [id
DX
] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id D
O
] ''
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id
EA
] ''
>CGemv{no_inplace} [id
EB
] ''
for{cpu,scan_fn}.2 [id H] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id D
K
] ''
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id
DW
] ''
>CGemv{no_inplace} [id
DX
] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id D
O
] ''
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id
EA
] ''
>CGemv{no_inplace} [id
EB
] ''
for{cpu,scan_fn}.2 [id H] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id D
K
] ''
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id
DW
] ''
>CGemv{no_inplace} [id
DX
] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id D
O
] ''
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id
EA
] ''
>CGemv{no_inplace} [id
EB
] ''
"""
%
{
"float"
:
theano
.
config
.
floatX
}
...
...
tests/tensor/test_basic.py
浏览文件 @
6a2259f4
...
...
@@ -2761,6 +2761,10 @@ class TestAsTensorVariable:
def
setup_method
(
self
):
self
.
x
=
tensor
.
scalar
(
"x"
)
def
test_tensor_from_scalar
(
self
):
y
=
as_tensor_variable
(
scal
.
int8
())
assert
isinstance
(
y
.
owner
.
op
,
TensorFromScalar
)
def
test_one_output
(
self
):
good_apply_var
=
ApplyDefaultTestOp
(
0
)
.
make_node
(
self
.
x
)
as_tensor_variable
(
good_apply_var
)
...
...
@@ -5747,81 +5751,50 @@ class TestDot:
assert
g
.
broadcastable
==
y
.
broadcastable
class
TestTensorfromscalar
:
def
test_basic
(
self
):
s
=
scal
.
constant
(
56
)
t
=
tensor_from_scalar
(
s
)
assert
t
.
owner
.
op
is
tensor_from_scalar
assert
t
.
type
.
broadcastable
==
(),
t
.
type
.
broadcastable
assert
t
.
type
.
ndim
==
0
,
t
.
type
.
ndim
assert
t
.
type
.
dtype
==
s
.
type
.
dtype
v
=
eval_outputs
([
t
])
assert
v
==
56
,
v
assert
isinstance
(
v
,
np
.
ndarray
)
assert
v
.
shape
==
(),
v
.
shape
def
test_basic_1
(
self
):
s
=
scal
.
constant
(
56
)
t
=
as_tensor_variable
(
s
)
assert
t
.
owner
.
op
is
tensor_from_scalar
assert
t
.
type
.
broadcastable
==
(),
t
.
type
.
broadcastable
assert
t
.
type
.
ndim
==
0
,
t
.
type
.
ndim
assert
t
.
type
.
dtype
==
s
.
type
.
dtype
v
=
eval_outputs
([
t
])
assert
v
==
56
,
v
assert
isinstance
(
v
,
np
.
ndarray
)
assert
v
.
shape
==
(),
v
.
shape
def
test_TensorFromScalar
():
s
=
scal
.
constant
(
56
)
t
=
tensor_from_scalar
(
s
)
assert
t
.
owner
.
op
is
tensor_from_scalar
assert
t
.
type
.
broadcastable
==
(),
t
.
type
.
broadcastable
assert
t
.
type
.
ndim
==
0
,
t
.
type
.
ndim
assert
t
.
type
.
dtype
==
s
.
type
.
dtype
g
=
grad
(
t
,
s
)
assert
eval_outputs
([
g
])
==
0.0
v
=
eval_outputs
([
t
])
def
test_basic_2
(
self
):
s
=
scal
.
constant
(
56.0
)
t
=
as_tensor_variable
(
s
)
assert
t
.
owner
.
op
is
tensor_from_scalar
assert
t
.
type
.
broadcastable
==
(),
t
.
type
.
broadcastable
assert
t
.
type
.
ndim
==
0
,
t
.
type
.
ndim
assert
t
.
type
.
dtype
==
s
.
type
.
dtype
v
=
eval_outputs
([
t
])
assert
v
==
56
,
v
assert
isinstance
(
v
,
np
.
ndarray
)
assert
v
.
shape
==
(),
v
.
shape
assert
v
==
56.0
,
v
assert
isinstance
(
v
,
np
.
ndarray
)
assert
v
.
shape
==
(),
v
.
shape
g
=
grad
(
t
,
s
)
assert
eval_outputs
([
g
])
==
0.0
g
=
grad
(
t
,
s
)
assert
eval_outputs
([
g
])
==
1.0
def
test_ScalarFromTensor
():
tt
=
constant
(
56
)
# scal.constant(56)
ss
=
scalar_from_tensor
(
tt
)
assert
ss
.
owner
.
op
is
scalar_from_tensor
assert
ss
.
type
.
dtype
==
tt
.
type
.
dtype
class
TestScalarfromtensor
:
def
test_basic
(
self
):
tt
=
constant
(
56
)
# scal.constant(56)
ss
=
scalar_from_tensor
(
tt
)
assert
ss
.
owner
.
op
is
scalar_from_tensor
assert
ss
.
type
.
dtype
==
tt
.
type
.
dtype
v
=
eval_outputs
([
ss
])
v
=
eval_outputs
([
ss
])
assert
v
==
56
assert
v
.
shape
==
()
assert
v
==
56
if
config
.
cast_policy
==
"custom"
:
assert
isinstance
(
v
,
np
.
int8
)
elif
config
.
cast_policy
in
(
"numpy"
,
"numpy+floatX"
):
assert
isinstance
(
v
,
str
(
np
.
asarray
(
56
)
.
dtype
))
else
:
raise
NotImplementedError
(
config
.
cast_policy
)
assert
v
.
shape
==
()
tt
=
lscalar
()
ss
=
scalar_from_tensor
(
tt
)
ss
.
owner
.
op
.
grad
([
tt
],
[
ss
])
fff
=
function
([
tt
],
ss
)
v
=
fff
(
np
.
asarray
(
5
))
assert
v
==
5
assert
isinstance
(
v
,
np
.
int64
)
assert
v
.
shape
==
()
if
config
.
cast_policy
==
"custom"
:
assert
isinstance
(
v
,
np
.
int8
)
elif
config
.
cast_policy
in
(
"numpy"
,
"numpy+floatX"
):
assert
isinstance
(
v
,
str
(
np
.
asarray
(
56
)
.
dtype
))
else
:
raise
NotImplementedError
(
config
.
cast_policy
)
tt
=
lscalar
()
ss
=
scalar_from_tensor
(
tt
)
ss
.
owner
.
op
.
grad
([
tt
],
[
ss
])
fff
=
function
([
tt
],
ss
)
v
=
fff
(
np
.
asarray
(
5
))
assert
v
==
5
assert
isinstance
(
v
,
np
.
int64
)
assert
v
.
shape
==
()
class
TestGrad
:
...
...
tests/test_printing.py
浏览文件 @
6a2259f4
...
...
@@ -656,61 +656,61 @@ def test_scan_debugprint5():
| | | | | | |for{cpu,scan_fn} [id F] ''
| | | | | | |Constant{1} [id BT]
| | | | | |InplaceDimShuffle{x,x} [id BU] ''
| | | | | |TensorConstant{0.0} [id B
P
]
| | | | |Elemwise{second} [id B
V
] ''
| | | | | |Subtensor{int64} [id B
W
] ''
| | | | | |TensorConstant{0.0} [id B
V
]
| | | | |Elemwise{second} [id B
W
] ''
| | | | | |Subtensor{int64} [id B
X
] ''
| | | | | | |Subtensor{int64::} [id BS] ''
| | | | | | |Constant{-1} [id B
X
]
| | | | | |InplaceDimShuffle{x} [id B
Y
] ''
| | | | | |Elemwise{second,no_inplace} [id
BZ
] ''
| | | | | |Sum{acc_dtype=float64} [id C
A
] ''
| | | | | | |Subtensor{int64} [id B
W
] ''
| | | | | |TensorConstant{1.0} [id
R
]
| | | | |Constant{-1} [id B
X
]
| | | | | | |Constant{-1} [id B
Y
]
| | | | | |InplaceDimShuffle{x} [id B
Z
] ''
| | | | | |Elemwise{second,no_inplace} [id
CA
] ''
| | | | | |Sum{acc_dtype=float64} [id C
B
] ''
| | | | | | |Subtensor{int64} [id B
X
] ''
| | | | | |TensorConstant{1.0} [id
CC
]
| | | | |Constant{-1} [id B
Y
]
| | | |Constant{1} [id BT]
| | |Constant{-1} [id C
B
]
| |Alloc [id C
C
] ''
| | |TensorConstant{0.0} [id
BP
]
| | |Elemwise{add,no_inplace} [id C
D
] ''
| | |Constant{-1} [id C
D
]
| |Alloc [id C
E
] ''
| | |TensorConstant{0.0} [id
CF
]
| | |Elemwise{add,no_inplace} [id C
G
] ''
| | | |Elemwise{sub,no_inplace} [id C] ''
| | | |TensorConstant{1} [id
Y
]
| | |Subtensor{int64} [id C
E
] ''
| | |Shape [id C
F
] ''
| | | |TensorConstant{1} [id
CH
]
| | |Subtensor{int64} [id C
I
] ''
| | |Shape [id C
J
] ''
| | | |A [id P]
| | |Constant{0} [id C
G
]
| | |Constant{0} [id C
K
]
| |A [id P]
|Constant{-1} [id C
H
]
|Constant{-1} [id C
L
]
Inner graphs of the scan ops:
for{cpu,grad_of_scan_fn}.1 [id B] ''
>Elemwise{add,no_inplace} [id C
I
] ''
> |Elemwise{mul} [id C
J
] ''
> | |<TensorType(float64, vector)> [id C
K
] -> [id BL]
> | |A_copy [id C
L
] -> [id P]
> |<TensorType(float64, vector)> [id C
M
] -> [id BL]
>Elemwise{add,no_inplace} [id C
N
] ''
> |Elemwise{mul} [id C
O
] ''
> | |<TensorType(float64, vector)> [id C
K
] -> [id BL]
> | |<TensorType(float64, vector)> [id C
P
] -> [id Z]
> |<TensorType(float64, vector)> [id C
Q] -> [id CC
]
>Elemwise{add,no_inplace} [id C
M
] ''
> |Elemwise{mul} [id C
N
] ''
> | |<TensorType(float64, vector)> [id C
O
] -> [id BL]
> | |A_copy [id C
P
] -> [id P]
> |<TensorType(float64, vector)> [id C
Q
] -> [id BL]
>Elemwise{add,no_inplace} [id C
R
] ''
> |Elemwise{mul} [id C
S
] ''
> | |<TensorType(float64, vector)> [id C
O
] -> [id BL]
> | |<TensorType(float64, vector)> [id C
T
] -> [id Z]
> |<TensorType(float64, vector)> [id C
U] -> [id CE
]
for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id C
R
] ''
> |<TensorType(float64, vector)> [id C
P
] -> [id H]
> |A_copy [id C
L
] -> [id P]
>Elemwise{mul,no_inplace} [id C
V
] ''
> |<TensorType(float64, vector)> [id C
T
] -> [id H]
> |A_copy [id C
P
] -> [id P]
for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id C
R
] ''
>Elemwise{mul,no_inplace} [id C
V
] ''
for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id C
R
] ''
>Elemwise{mul,no_inplace} [id C
V
] ''
for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id C
R
] ''
>Elemwise{mul,no_inplace} [id C
V
] ''
for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id C
R
] ''"""
>Elemwise{mul,no_inplace} [id C
V
] ''"""
for
truth
,
out
in
zip
(
expected_output
.
split
(
"
\n
"
),
lines
):
assert
truth
.
strip
()
==
out
.
strip
()
...
...
theano/gof/__init__.py
浏览文件 @
6a2259f4
...
...
@@ -40,7 +40,6 @@ e-mail thread "What is gof?".
from
theano.gof.cc
import
CLinker
,
OpWiseCLinker
,
DualLinker
,
HideC
from
theano.gof.fg
import
(
CachedConstantError
,
InconsistencyError
,
MissingInputError
,
FunctionGraph
,
...
...
theano/gof/fg.py
浏览文件 @
6a2259f4
...
...
@@ -20,17 +20,6 @@ from theano.misc.ordered_set import OrderedSet
NullType
=
None
class
CachedConstantError
(
Exception
):
"""
An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
cached constant in other FunctionGraph.
"""
pass
class
InconsistencyError
(
Exception
):
"""
This exception should be thrown by listeners to FunctionGraph when the
...
...
@@ -186,15 +175,7 @@ class FunctionGraph(utils.object2):
self
.
__setup_r__
(
input
)
self
.
variables
.
add
(
input
)
# Setup a Variable #
def
__setup_r__
(
self
,
r
):
# sets up r so it belongs to this fgraph
if
getattr
(
r
,
"cached"
,
False
):
raise
CachedConstantError
(
"You manually constructed a FunctionGraph, but you passed it a"
" graph that has a cached constant. This should not happen."
" Clone the graph before building the FunctionGraph."
)
if
hasattr
(
r
,
"fgraph"
)
and
r
.
fgraph
is
not
None
and
r
.
fgraph
is
not
self
:
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
r
.
fgraph
=
self
...
...
theano/gof/opt.py
浏览文件 @
6a2259f4
...
...
@@ -93,12 +93,7 @@ class Optimizer(object):
"""
self
.
add_requirements
(
fgraph
)
try
:
orig
=
theano
.
tensor
.
basic
.
constant
.
enable
theano
.
tensor
.
basic
.
constant
.
enable
=
False
ret
=
self
.
apply
(
fgraph
,
*
args
,
**
kwargs
)
finally
:
theano
.
tensor
.
basic
.
constant
.
enable
=
orig
ret
=
self
.
apply
(
fgraph
,
*
args
,
**
kwargs
)
return
ret
def
__call__
(
self
,
fgraph
):
...
...
theano/tensor/basic.py
浏览文件 @
6a2259f4
...
...
@@ -26,7 +26,6 @@ from theano.tensor.var import (
AsTensorError
,
TensorVariable
,
TensorConstant
,
TensorConstantSignature
,
_tensor_py_operators
,
)
from
theano.tensor.type
import
TensorType
,
values_eq_approx_always_true
...
...
@@ -217,7 +216,7 @@ as_tensor = as_tensor_variable
def
constant
(
x
,
name
=
None
,
ndim
=
None
,
dtype
=
None
):
"""Return a
symbolic `
Constant` with value `x`.
"""Return a
`Tensor
Constant` with value `x`.
Raises
------
...
...
@@ -226,16 +225,6 @@ def constant(x, name=None, ndim=None, dtype=None):
ValueError
`x` could not be expanded to have ndim dimensions.
Notes
-----
We create a small cache of frequently used constant.
This speed up the Merge optimization for big graph.
We want to cache all scalar to don't merge as frequently constants.
But we don't want to cache too much stuff.
So we cache integer with dtype [u]int and float where the value is
between -10 and 10.
We cache all broadcast pattern for scalar.
"""
x_
=
scal
.
convert
(
x
,
dtype
=
dtype
)
...
...
@@ -252,40 +241,11 @@ def constant(x, name=None, ndim=None, dtype=None):
try
:
ttype
=
TensorType
(
dtype
=
x_
.
dtype
,
broadcastable
=
bcastable
)
if
not
constant
.
enable
:
return
TensorConstant
(
ttype
,
x_
,
name
=
name
)
sig
=
TensorConstantSignature
((
ttype
,
x_
))
if
sig
in
constant_cache
:
return
constant_cache
[
sig
]
ret
=
TensorConstant
(
ttype
,
x_
,
name
=
name
)
if
(
x_
.
size
==
1
and
(
-
10
)
<=
x_
<=
10
and
(
x_
.
dtype
in
int_dtypes
or
x_
.
dtype
in
uint_dtypes
or
(
x_
.
dtype
in
float_dtypes
and
# Limit the size of the cache.
len
(
constant_cache
)
<
10000
)
)
):
constant_cache
[
sig
]
=
ret
# This is needed to raise a good error to the user.
ret
.
cached
=
True
return
ret
return
TensorConstant
(
ttype
,
x_
,
name
=
name
)
except
Exception
:
raise
TypeError
(
"Could not convert
%
s to TensorType"
%
x
,
type
(
x
))
constant
.
enable
=
True
constant_cache
=
{}
def
_obj_is_wrappable_as_tensor
(
x
):
try
:
constant
(
x
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论