Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
043c3eef
提交
043c3eef
authored
11月 04, 2016
作者:
Pascal Lamblin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Some pep8
上级
7215c905
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
207 行增加
和
251 行删除
+207
-251
test_opt.py
theano/tensor/tests/test_opt.py
+207
-251
没有找到文件。
theano/tensor/tests/test_opt.py
浏览文件 @
043c3eef
from
__future__
import
absolute_import
,
print_function
,
division
from
__future__
import
absolute_import
,
print_function
,
division
# PENDING REWRITE OF tensor_opt.py
import
copy
import
copy
import
logging
import
logging
import
pickle
import
os
import
os
import
sys
import
sys
import
time
import
time
...
@@ -13,8 +11,6 @@ import numpy
...
@@ -13,8 +11,6 @@ import numpy
from
six.moves
import
xrange
from
six.moves
import
xrange
from
nose.plugins.skip
import
SkipTest
from
nose.plugins.skip
import
SkipTest
from
nose.tools
import
assert_raises
,
assert_true
from
nose.tools
import
assert_raises
,
assert_true
from
numpy.testing
import
dec
from
numpy.testing.noseclasses
import
KnownFailureTest
import
theano
import
theano
import
theano.scalar
as
scal
import
theano.scalar
as
scal
...
@@ -30,47 +26,44 @@ from theano import shared
...
@@ -30,47 +26,44 @@ from theano import shared
from
theano.gof
import
FunctionGraph
from
theano.gof
import
FunctionGraph
import
theano.tensor.opt
as
opt
import
theano.tensor.opt
as
opt
from
theano.tensor.opt
import
(
from
theano.tensor.opt
import
(
local_add_specialize
,
local_add_specialize
,
local_dimshuffle_lift
,
local_dimshuffle_lift
,
local_useless_dimshuffle_in_reshape
,
local_useless_dimshuffle_in_reshape
,
local_useless_alloc
,
local_useless_alloc
,
local_merge_alloc
,
local_merge_alloc
,
local_greedy_distributor
,
local_greedy_distributor
,
local_useless_reshape
,
local_useless_reshape
,
local_reshape_to_dimshuffle
,
local_reshape_to_dimshuffle
,
mul_canonizer
,
mul_canonizer
,
Shape_i
,
Shape_i
,
Assert
,
Assert
,
MakeVector
,
MakeVector
,
make_vector
,
make_vector
,
local_expm1
,
local_canonicalize_alloc
local_canonicalize_alloc
)
)
from
theano
import
tensor
from
theano
import
tensor
from
theano
import
tensor
as
T
from
theano
import
tensor
as
T
from
theano.tensor
import
scalar
,
iscalar
,
lscalar
,
fscalar
,
dscalar
from
theano.tensor
import
scalar
,
iscalar
,
lscalar
,
fscalar
,
dscalar
from
theano.tensor
import
vector
,
ivector
,
lvector
,
fvector
,
dvector
from
theano.tensor
import
vector
,
lvector
,
fvector
,
dvector
from
theano.tensor
import
matrix
,
imatrix
,
lmatrix
,
fmatrix
,
dmatrix
,
tensor3
from
theano.tensor
import
matrix
,
fmatrix
,
dmatrix
,
tensor3
from
theano.tensor
import
scalars
,
vectors
,
matrices
,
fmatrices
,
dmatrices
from
theano.tensor
import
vectors
,
matrices
,
fmatrices
,
dmatrices
from
theano.tensor
import
(
from
theano.tensor
import
(
AdvancedSubtensor
,
AdvancedSubtensor
,
AdvancedSubtensor1
,
AdvancedSubtensor1
,
as_tensor_variable
,
as_tensor_variable
,
IncSubtensor
,
IncSubtensor
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
,
AdvancedIncSubtensor1
,
inplace
,
inplace
,
Join
,
Join
,
join
,
join
,
Subtensor
,
Subtensor
,
TensorType
,
TensorType
,
tile
tile
)
)
from
theano.tensor.elemwise
import
DimShuffle
from
theano.tensor.elemwise
import
DimShuffle
from
theano.tensor.type
import
values_eq_approx_remove_nan
from
theano.tensor.type
import
values_eq_approx_remove_nan
from
theano.tests
import
unittest_tools
as
utt
from
theano.tests
import
unittest_tools
as
utt
from
theano.compile.mode
import
optdb
from
theano.compile
import
Mode
from
theano.gof.opt
import
check_stack_trace
,
out2in
from
theano.gof.opt
import
check_stack_trace
,
out2in
from
nose.plugins.attrib
import
attr
from
nose.plugins.attrib
import
attr
...
@@ -79,7 +72,6 @@ if mode_opt == 'FAST_COMPILE':
...
@@ -79,7 +72,6 @@ if mode_opt == 'FAST_COMPILE':
mode_opt
=
'FAST_RUN'
mode_opt
=
'FAST_RUN'
mode_opt
=
theano
.
compile
.
mode
.
get_mode
(
mode_opt
)
mode_opt
=
theano
.
compile
.
mode
.
get_mode
(
mode_opt
)
ds
=
lambda
x
,
y
:
DimShuffle
(
x
.
type
.
broadcastable
,
y
)(
x
)
dimshuffle_lift
=
out2in
(
local_dimshuffle_lift
)
dimshuffle_lift
=
out2in
(
local_dimshuffle_lift
)
_optimizer_stabilize
=
gof
.
Query
(
include
=
[
'fast_run'
])
_optimizer_stabilize
=
gof
.
Query
(
include
=
[
'fast_run'
])
...
@@ -94,6 +86,10 @@ _optimizer_fast_run = gof.Query(include=['fast_run'])
...
@@ -94,6 +86,10 @@ _optimizer_fast_run = gof.Query(include=['fast_run'])
_optimizer_fast_run
=
compile
.
optdb
.
query
(
_optimizer_fast_run
)
_optimizer_fast_run
=
compile
.
optdb
.
query
(
_optimizer_fast_run
)
def
ds
(
x
,
y
):
return
DimShuffle
(
x
.
type
.
broadcastable
,
y
)(
x
)
def
optimize
(
g
,
level
=
'fast_run'
):
def
optimize
(
g
,
level
=
'fast_run'
):
if
level
==
'fast_run'
:
if
level
==
'fast_run'
:
_optimizer_fast_run
.
optimize
(
g
)
_optimizer_fast_run
.
optimize
(
g
)
...
@@ -138,8 +134,8 @@ class test_dimshuffle_lift(unittest.TestCase):
...
@@ -138,8 +134,8 @@ class test_dimshuffle_lift(unittest.TestCase):
x
,
y
,
z
=
inputs
()
x
,
y
,
z
=
inputs
()
e
=
ds
(
ds
(
ds
(
x
,
(
0
,
'x'
,
1
)),
(
2
,
0
,
'x'
,
1
)),
(
1
,
0
))
e
=
ds
(
ds
(
ds
(
x
,
(
0
,
'x'
,
1
)),
(
2
,
0
,
'x'
,
1
)),
(
1
,
0
))
g
=
FunctionGraph
([
x
],
[
e
])
g
=
FunctionGraph
([
x
],
[
e
])
self
.
assertTrue
(
str
(
g
)
==
"[InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
self
.
assertTrue
(
str
(
g
)
==
(
"[InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
"(InplaceDimShuffle{0,x,1}(x)))]"
,
"(InplaceDimShuffle{0,x,1}(x)))]"
)
,
str
(
g
))
str
(
g
))
dimshuffle_lift
.
optimize
(
g
)
dimshuffle_lift
.
optimize
(
g
)
self
.
assertTrue
(
str
(
g
)
==
"[x]"
,
str
(
g
))
self
.
assertTrue
(
str
(
g
)
==
"[x]"
,
str
(
g
))
...
@@ -259,7 +255,7 @@ def test_local_useless_dimshuffle_in_reshape():
...
@@ -259,7 +255,7 @@ def test_local_useless_dimshuffle_in_reshape():
h
=
FunctionGraph
([
mat
],
[
reshape_dimshuffle_mat2
])
h
=
FunctionGraph
([
mat
],
[
reshape_dimshuffle_mat2
])
str_h
=
str
(
h
)
str_h
=
str
(
h
)
useless_dimshuffle_in_reshape
.
optimize
(
h
)
useless_dimshuffle_in_reshape
.
optimize
(
h
)
assert_true
(
str
(
h
)
==
str
(
h
)
)
assert_true
(
str
(
h
)
==
str
_h
)
def
test_add_canonizer_problem0
():
def
test_add_canonizer_problem0
():
...
@@ -269,6 +265,7 @@ def test_add_canonizer_problem0():
...
@@ -269,6 +265,7 @@ def test_add_canonizer_problem0():
r
=
segment_labels
*
5
r
=
segment_labels
*
5
f
=
function
([
label
],
r
)
f
=
function
([
label
],
r
)
f
(
3
)
class
test_greedy_distribute
(
unittest
.
TestCase
):
class
test_greedy_distribute
(
unittest
.
TestCase
):
...
@@ -300,8 +297,8 @@ class test_greedy_distribute(unittest.TestCase):
...
@@ -300,8 +297,8 @@ class test_greedy_distribute(unittest.TestCase):
eps
=
scalar
(
'eps'
)
eps
=
scalar
(
'eps'
)
s
=
scalar
(
's'
)
s
=
scalar
(
's'
)
#r = theano.tensor.mul(theano.tensor.fill(x, 2.*a), x/a , (y+z) , a)
#
r = theano.tensor.mul(theano.tensor.fill(x, 2.*a), x/a , (y+z) , a)
#r = theano.tensor.mul((x/a+y) , a, z)
#
r = theano.tensor.mul((x/a+y) , a, z)
r
=
tensor
.
mul
(
s
-
1
,
r
=
tensor
.
mul
(
s
-
1
,
eps
+
x
/
s
,
eps
+
x
/
s
,
eps
+
y
/
s
,
eps
+
y
/
s
,
...
@@ -326,16 +323,16 @@ class test_canonize(unittest.TestCase):
...
@@ -326,16 +323,16 @@ class test_canonize(unittest.TestCase):
def
test_muldiv
(
self
):
def
test_muldiv
(
self
):
x
,
y
,
z
=
matrices
(
'xyz'
)
x
,
y
,
z
=
matrices
(
'xyz'
)
a
,
b
,
c
,
d
=
matrices
(
'abcd'
)
a
,
b
,
c
,
d
=
matrices
(
'abcd'
)
#
e = (2.0 * x) / (2.0 * y)
#
e = (2.0 * x) / (2.0 * y)
#
e = (2.0 * x) / (4.0 * y)
#
e = (2.0 * x) / (4.0 * y)
#
e = x / (y / z)
#
e = x / (y / z)
#
e = (x * y) / x
#
e = (x * y) / x
#
e = (x / y) * (y / z) * (z / x)
#
e = (x / y) * (y / z) * (z / x)
#
e = (a / b) * (b / c) * (c / d)
#
e = (a / b) * (b / c) * (c / d)
#
e = (a * b) / (b * c) / (c * d)
#
e = (a * b) / (b * c) / (c * d)
#
e = 2 * x / 2
#
e = 2 * x / 2
#
e = x / y / x
#
e = x / y / x
#
e = (x / x) * (y / y)
#
e = (x / x) * (y / y)
e
=
(
-
1
*
x
)
/
y
/
(
-
2
*
z
)
e
=
(
-
1
*
x
)
/
y
/
(
-
2
*
z
)
g
=
FunctionGraph
([
x
,
y
,
z
,
a
,
b
,
c
,
d
],
[
e
])
g
=
FunctionGraph
([
x
,
y
,
z
,
a
,
b
,
c
,
d
],
[
e
])
print
(
pprint
(
g
.
outputs
[
0
]))
print
(
pprint
(
g
.
outputs
[
0
]))
...
@@ -355,60 +352,60 @@ class test_canonize(unittest.TestCase):
...
@@ -355,60 +352,60 @@ class test_canonize(unittest.TestCase):
shp
=
(
5
,
5
)
shp
=
(
5
,
5
)
fx
,
fy
,
fz
=
fmatrices
(
'xyz'
)
fx
,
fy
,
fz
=
fmatrices
(
'xyz'
)
dx
,
dy
,
dz
=
dmatrices
(
'xyz'
)
dx
,
dy
,
dz
=
dmatrices
(
'xyz'
)
fv
=
fvector
(
'r'
)
.
dimshuffle
(
'x'
,
0
)
#
fv = fvector('r').dimshuffle('x', 0)
dv
=
dvector
(
's'
)
.
dimshuffle
(
'x'
,
0
)
#
dv = dvector('s').dimshuffle('x', 0)
fxv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shp
),
dtype
=
'float32'
)
fxv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shp
),
dtype
=
'float32'
)
fyv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shp
),
dtype
=
'float32'
)
fyv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shp
),
dtype
=
'float32'
)
fzv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shp
),
dtype
=
'float32'
)
fzv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shp
),
dtype
=
'float32'
)
fvv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
shp
[
0
]),
dtype
=
'float32'
)
.
reshape
(
1
,
shp
[
0
])
#
fvv = theano._asarray(numpy.random.rand(shp[0]), dtype='float32').reshape(1, shp[0])
dxv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shp
),
dtype
=
'float64'
)
#
dxv = theano._asarray(numpy.random.rand(*shp), dtype='float64')
dyv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shp
),
dtype
=
'float64'
)
#
dyv = theano._asarray(numpy.random.rand(*shp), dtype='float64')
dzv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shp
),
dtype
=
'float64'
)
#
dzv = theano._asarray(numpy.random.rand(*shp), dtype='float64')
dvv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
shp
[
0
]),
dtype
=
'float64'
)
.
reshape
(
1
,
shp
[
0
])
#
dvv = theano._asarray(numpy.random.rand(shp[0]), dtype='float64').reshape(1, shp[0])
cases
=
[
cases
=
[
(
fx
+
fy
,
(
fx
,
fy
),
(
fxv
,
fyv
),
1
,
'float32'
),
(
fx
+
fy
,
(
fx
,
fy
),
(
fxv
,
fyv
),
1
,
'float32'
),
(
fx
*
fy
,
(
fx
,
fy
),
(
fxv
,
fyv
),
1
,
'float32'
),
(
fx
*
fy
,
(
fx
,
fy
),
(
fxv
,
fyv
),
1
,
'float32'
),
#
(fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
#
(fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
#
(dx+dy+dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'),
#
(dx+dy+dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'),
#
(fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
#
(fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
#
(dx*dy*dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'),
#
(dx*dy*dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'),
#
(fx*fy*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
#
(fx*fy*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
#
(dx*dy*(dx+dy+dz),(dx,dy,dz),(dxv,dyv,dzv),2,'float64'),
#
(dx*dy*(dx+dy+dz),(dx,dy,dz),(dxv,dyv,dzv),2,'float64'),
# (fx*fy*(fx+fy+dz),(fx,fy,dz),(dxv,dyv,dzv),2,'float64'),#
check mixed type add
# (fx*fy*(fx+fy+dz),(fx,fy,dz),(dxv,dyv,dzv),2,'float64'), #
check mixed type add
# (dz*fy*(fx+fy),(fx,fy,dz),(dxv,dyv,dzv),2,'float64'),#
check mixed type mul
# (dz*fy*(fx+fy),(fx,fy,dz),(dxv,dyv,dzv),2,'float64'), #
check mixed type mul
# check with dimshuffle of constant
# check with dimshuffle of constant
(
fx
+
fy
+
fz
+
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
{
'custom'
:
(
fx
+
fy
+
fz
+
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
(
fx
*
fy
*
fz
*
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
{
'custom'
:
(
fx
*
fy
*
fz
*
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
#
(2+fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
#
(2+fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
#
(2*fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
#
(2*fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(
2
+
fx
+
fy
+
fz
+
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
{
(
2
+
fx
+
fy
+
fz
+
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
(
2
*
fx
*
fy
*
fz
*
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
{
(
2
*
fx
*
fy
*
fz
*
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
#
(fx*fy*2*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
#
(fx*fy*2*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
#
(fx*fy*(2+fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
#
(fx*fy*(2+fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
(
fx
*
fy
*
2
*
(
fx
+
fy
+
fz
+
2
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
2
,
{
(
fx
*
fy
*
2
*
(
fx
+
fy
+
fz
+
2
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
2
,
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
# check with broadcast of row
# check with broadcast of row
#
(fx+fy+fz+fv,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
#
(fx+fy+fz+fv,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
#
(fx*fy*fz*fv,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
#
(fx*fy*fz*fv,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
#
(fv+fx+fy+fz,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
#
(fv+fx+fy+fz,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
#
(fv*fx*fy*fz,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
#
(fv*fx*fy*fz,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
#
(fx*fy*fv*(fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
#
(fx*fy*fv*(fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
#
(fx*fy*(fv+fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
#
(fx*fy*(fv+fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
#
(fx*fy*fv*(fv+fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
#
(fx*fy*fv*(fv+fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
#
(dx+dy+dz+dv,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'),
#
(dx+dy+dz+dv,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'),
#
(dx*dy*dz*dv,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'),
#
(dx*dy*dz*dv,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'),
#
(dv+dx+dy+dz,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'),
#
(dv+dx+dy+dz,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'),
#
(dv*dx*dy*dz,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'),
#
(dv*dx*dy*dz,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'),
#
(dx*dy*dv*(dx+dy+dz),(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),2,'float64'),
#
(dx*dy*dv*(dx+dy+dz),(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),2,'float64'),
#
(dx*dy*(dv+dx+dy+dz),(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),2,'float64'),
#
(dx*dy*(dv+dx+dy+dz),(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),2,'float64'),
#
(dx*dy*dv*(dv+dx+dy+dz),(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),2,'float64'),
#
(dx*dy*dv*(dv+dx+dy+dz),(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),2,'float64'),
]
# [10:11]
]
# [10:11]
#
print cases
#
print cases
# We must be sure that the Canonizer is working, but that we don't have other
# We must be sure that the Canonizer is working, but that we don't have other
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion
...
@@ -457,61 +454,38 @@ class test_canonize(unittest.TestCase):
...
@@ -457,61 +454,38 @@ class test_canonize(unittest.TestCase):
(
dx
+
dy
+
dz
,
(
dx
,
dy
,
dz
),
(
dxv
,
dyv
,
dzv
),
1
,
'float64'
),
(
dx
+
dy
+
dz
,
(
dx
,
dy
,
dz
),
(
dxv
,
dyv
,
dzv
),
1
,
'float64'
),
(
fx
*
fy
*
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'float32'
),
(
fx
*
fy
*
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'float32'
),
(
dx
*
dy
*
dz
,
(
dx
,
dy
,
dz
),
(
dxv
,
dyv
,
dzv
),
1
,
'float64'
),
(
dx
*
dy
*
dz
,
(
dx
,
dy
,
dz
),
(
dxv
,
dyv
,
dzv
),
1
,
'float64'
),
(
fx
*
fy
*
(
fx
+
fy
+
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
(
fx
*
fy
*
(
fx
+
fy
+
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
2
,
'float32'
),
fzv
),
2
,
'float32'
),
(
dx
*
dy
*
(
dx
+
dy
+
dz
),
(
dx
,
dy
,
dz
),
(
dxv
,
dyv
,
dzv
),
2
,
'float64'
),
(
dx
*
dy
*
(
dx
+
dy
+
dz
),
(
dx
,
dy
,
dz
),
(
dxv
,
dyv
,
(
fx
*
fy
*
(
fx
+
fy
+
dz
),
(
fx
,
fy
,
dz
),
(
dxv
,
dyv
,
dzv
),
2
,
'float64'
),
# check mixed type add
dzv
),
2
,
'float64'
),
(
dz
*
fy
*
(
fx
+
fy
),
(
fx
,
fy
,
dz
),
(
dxv
,
dyv
,
dzv
),
2
,
'float64'
),
# check mixed type mul
(
fx
*
fy
*
(
fx
+
fy
+
dz
),
(
fx
,
fy
,
dz
),
(
dxv
,
dyv
,
dzv
),
2
,
'float64'
),
# check mixed type add
(
dz
*
fy
*
(
fx
+
fy
),
(
fx
,
fy
,
dz
),
(
dxv
,
dyv
,
dzv
),
2
,
'float64'
),
# check mixed type mul
# check with dimshuffle of constant
# check with dimshuffle of constant
(
fx
+
fy
+
fz
+
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'float32'
),
(
fx
+
fy
+
fz
+
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'float32'
),
(
fx
*
fy
*
fz
*
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'float32'
),
(
fx
*
fy
*
fz
*
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'float32'
),
(
2
+
fx
+
fy
+
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'float32'
),
(
2
+
fx
+
fy
+
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'float32'
),
(
2
*
fx
*
fy
*
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'float32'
),
(
2
*
fx
*
fy
*
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'float32'
),
(
2
+
fx
+
fy
+
fz
+
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
(
2
+
fx
+
fy
+
fz
+
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'float32'
),
fzv
),
1
,
'float32'
),
(
2
*
fx
*
fy
*
fz
*
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
'float32'
),
(
2
*
fx
*
fy
*
fz
*
2
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
(
fx
*
fy
*
2
*
(
fx
+
fy
+
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
2
,
'float32'
),
fzv
),
1
,
'float32'
),
(
fx
*
fy
*
(
2
+
fx
+
fy
+
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
2
,
'float32'
),
(
fx
*
fy
*
2
*
(
fx
+
fy
+
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
(
fx
*
fy
*
2
*
(
fx
+
fy
+
fz
+
2
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
2
,
'float32'
),
fzv
),
2
,
'float32'
),
(
fx
*
fy
*
(
2
+
fx
+
fy
+
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
2
,
'float32'
),
(
fx
*
fy
*
2
*
(
fx
+
fy
+
fz
+
2
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
2
,
'float32'
),
# check with broadcast of row
# check with broadcast of row
(
fx
+
fy
+
fz
+
fv
,
(
fx
,
fy
,
fz
,
fv
),
(
fxv
,
fyv
,
fzv
,
(
fx
+
fy
+
fz
+
fv
,
(
fx
,
fy
,
fz
,
fv
),
(
fxv
,
fyv
,
fzv
,
fvv
),
1
,
'float32'
),
fvv
),
1
,
'float32'
),
(
fx
*
fy
*
fz
*
fv
,
(
fx
,
fy
,
fz
,
fv
),
(
fxv
,
fyv
,
fzv
,
fvv
),
1
,
'float32'
),
(
fx
*
fy
*
fz
*
fv
,
(
fx
,
fy
,
fz
,
fv
),
(
fxv
,
fyv
,
fzv
,
(
fv
+
fx
+
fy
+
fz
,
(
fx
,
fy
,
fz
,
fv
),
(
fxv
,
fyv
,
fzv
,
fvv
),
1
,
'float32'
),
fvv
),
1
,
'float32'
),
(
fv
*
fx
*
fy
*
fz
,
(
fx
,
fy
,
fz
,
fv
),
(
fxv
,
fyv
,
fzv
,
fvv
),
1
,
'float32'
),
(
fv
+
fx
+
fy
+
fz
,
(
fx
,
fy
,
fz
,
fv
),
(
fxv
,
fyv
,
fzv
,
(
fx
*
fy
*
fv
*
(
fx
+
fy
+
fz
),
(
fx
,
fy
,
fz
,
fv
),
(
fxv
,
fyv
,
fzv
,
fvv
),
2
,
'float32'
),
fvv
),
1
,
'float32'
),
(
fx
*
fy
*
(
fv
+
fx
+
fy
+
fz
),
(
fx
,
fy
,
fz
,
fv
),
(
fxv
,
fyv
,
fzv
,
fvv
),
2
,
'float32'
),
(
fv
*
fx
*
fy
*
fz
,
(
fx
,
fy
,
fz
,
fv
),
(
fxv
,
fyv
,
fzv
,
(
fx
*
fy
*
fv
*
(
fv
+
fx
+
fy
+
fz
),
(
fx
,
fy
,
fz
,
fv
),
(
fxv
,
fyv
,
fzv
,
fvv
),
2
,
'float32'
),
fvv
),
1
,
'float32'
),
(
dx
+
dy
+
dz
+
dv
,
(
dx
,
dy
,
dz
,
dv
),
(
dxv
,
dyv
,
dzv
,
dvv
),
1
,
'float64'
),
(
fx
*
fy
*
fv
*
(
fx
+
fy
+
fz
),
(
fx
,
fy
,
fz
,
fv
),
(
fxv
,
fyv
,
(
dx
*
dy
*
dz
*
dv
,
(
dx
,
dy
,
dz
,
dv
),
(
dxv
,
dyv
,
dzv
,
dvv
),
1
,
'float64'
),
fzv
,
fvv
),
2
,
'float32'
),
(
dv
+
dx
+
dy
+
dz
,
(
dx
,
dy
,
dz
,
dv
),
(
dxv
,
dyv
,
dzv
,
dvv
),
1
,
'float64'
),
(
fx
*
fy
*
(
fv
+
fx
+
fy
+
fz
),
(
fx
,
fy
,
fz
,
fv
),
(
fxv
,
fyv
,
(
dv
*
dx
*
dy
*
dz
,
(
dx
,
dy
,
dz
,
dv
),
(
dxv
,
dyv
,
dzv
,
dvv
),
1
,
'float64'
),
fzv
,
fvv
),
2
,
'float32'
),
(
dx
*
dy
*
dv
*
(
dx
+
dy
+
dz
),
(
dx
,
dy
,
dz
,
dv
),
(
dxv
,
dyv
,
dzv
,
dvv
),
2
,
'float64'
),
(
fx
*
fy
*
fv
*
(
fv
+
fx
+
fy
+
fz
),
(
fx
,
fy
,
fz
,
fv
),
(
fxv
,
fyv
,
fzv
,
(
dx
*
dy
*
(
dv
+
dx
+
dy
+
dz
),
(
dx
,
dy
,
dz
,
dv
),
(
dxv
,
dyv
,
dzv
,
dvv
),
2
,
'float64'
),
fvv
),
2
,
'float32'
),
(
dx
*
dy
*
dv
*
(
dv
+
dx
+
dy
+
dz
),
(
dx
,
dy
,
dz
,
dv
),
(
dxv
,
dyv
,
dzv
,
dvv
),
2
,
'float64'
),
(
dx
+
dy
+
dz
+
dv
,
(
dx
,
dy
,
dz
,
dv
),
(
dxv
,
dyv
,
dzv
,
dvv
),
1
,
'float64'
),
(
dx
*
dy
*
dz
*
dv
,
(
dx
,
dy
,
dz
,
dv
),
(
dxv
,
dyv
,
dzv
,
dvv
),
1
,
'float64'
),
(
dv
+
dx
+
dy
+
dz
,
(
dx
,
dy
,
dz
,
dv
),
(
dxv
,
dyv
,
dzv
,
dvv
),
1
,
'float64'
),
(
dv
*
dx
*
dy
*
dz
,
(
dx
,
dy
,
dz
,
dv
),
(
dxv
,
dyv
,
dzv
,
dvv
),
1
,
'float64'
),
(
dx
*
dy
*
dv
*
(
dx
+
dy
+
dz
),
(
dx
,
dy
,
dz
,
dv
),
(
dxv
,
dyv
,
dzv
,
dvv
),
2
,
'float64'
),
(
dx
*
dy
*
(
dv
+
dx
+
dy
+
dz
),
(
dx
,
dy
,
dz
,
dv
),
(
dxv
,
dyv
,
dzv
,
dvv
),
2
,
'float64'
),
(
dx
*
dy
*
dv
*
(
dv
+
dx
+
dy
+
dz
),
(
dx
,
dy
,
dz
,
dv
),
(
dxv
,
dyv
,
dzv
,
dvv
),
2
,
'float64'
),
]
# [10:11]
]
# [10:11]
#
print cases
#
print cases
# We must be sure that the Canonizer is working, but that we don't have other
# We must be sure that the Canonizer is working, but that we don't have other
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion
...
@@ -568,11 +542,11 @@ class test_canonize(unittest.TestCase):
...
@@ -568,11 +542,11 @@ class test_canonize(unittest.TestCase):
'local_elemwise_fusion'
)
'local_elemwise_fusion'
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
# test x / x -> 1
# test x / x -> 1
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
enumerate
([
(
fx
/
fx
,
[
fx
],
[
fxv
],
'float32'
),
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
enumerate
([
(
dx
/
dx
,
[
dx
],
[
dxv
],
'float64
'
),
(
fx
/
fx
,
[
fx
],
[
fxv
],
'float32
'
),
(
fv
/
fv
,
[
fv
],
[
fvv
],
'float32
'
),
(
dx
/
dx
,
[
dx
],
[
dxv
],
'float64
'
),
(
dv
/
dv
,
[
dv
],
[
dvv
],
'float64
'
),
(
fv
/
fv
,
[
fv
],
[
fvv
],
'float32
'
),
]):
(
dv
/
dv
,
[
dv
],
[
dvv
],
'float64'
)
]):
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
mode
=
mode
)
mode
=
mode
)
out
=
f
(
*
val_inputs
)
out
=
f
(
*
val_inputs
)
...
@@ -591,16 +565,16 @@ class test_canonize(unittest.TestCase):
...
@@ -591,16 +565,16 @@ class test_canonize(unittest.TestCase):
# test (x * y) / x -> y
# test (x * y) / x -> y
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
nb_elemwise
,
out_dtype
)
in
enumerate
([
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
nb_elemwise
,
out_dtype
)
in
enumerate
([
((
dx
*
dy
)
/
dx
,
[
dx
,
dy
],
[
dxv
,
dyv
],
0
,
'float64'
),
((
dx
*
dy
)
/
dx
,
[
dx
,
dy
],
[
dxv
,
dyv
],
0
,
'float64'
),
((
fx
*
fy
)
/
fx
,
[
fx
,
fy
],
[
fxv
,
fyv
],
0
,
'float32'
),
((
fx
*
fy
)
/
fx
,
[
fx
,
fy
],
[
fxv
,
fyv
],
0
,
'float32'
),
((
dv
*
dy
)
/
dv
,
[
dv
,
dy
],
[
dvv
,
dyv
],
0
,
'float64'
),
((
dv
*
dy
)
/
dv
,
[
dv
,
dy
],
[
dvv
,
dyv
],
0
,
'float64'
),
((
fv
*
fy
)
/
fv
,
[
fv
,
fy
],
[
fvv
,
fyv
],
0
,
'float32'
),
((
fv
*
fy
)
/
fv
,
[
fv
,
fy
],
[
fvv
,
fyv
],
0
,
'float32'
),
# must broadcast as their
is a dimshuffle in the computation
# must broadcast as there
is a dimshuffle in the computation
((
dx
*
dv
)
/
dx
,
[
dx
,
dv
],
[
dxv
,
dvv
],
1
,
'float64'
),
((
dx
*
dv
)
/
dx
,
[
dx
,
dv
],
[
dxv
,
dvv
],
1
,
'float64'
),
# topo: [Elemwise{second,no_inplace}(x, <TensorType(float64, row)>)]
# topo: [Elemwise{second,no_inplace}(x, <TensorType(float64, row)>)]
((
fx
*
fv
)
/
fx
,
[
fx
,
fv
],
[
fxv
,
fvv
],
1
,
'float32'
)
((
fx
*
fv
)
/
fx
,
[
fx
,
fv
],
[
fxv
,
fvv
],
1
,
'float32'
)
# topo: [Elemwise{second,no_inplace}(x, <TensorType(float32, row)>)]
# topo: [Elemwise{second,no_inplace}(x, <TensorType(float32, row)>)]
]):
]):
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
mode
=
mode
)
mode
=
mode
)
out
=
f
(
*
val_inputs
)
out
=
f
(
*
val_inputs
)
...
@@ -614,19 +588,17 @@ class test_canonize(unittest.TestCase):
...
@@ -614,19 +588,17 @@ class test_canonize(unittest.TestCase):
# test x / y / x -> 1 / y
# test x / y / x -> 1 / y
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
nb_elemwise
,
out_dtype
)
in
enumerate
([
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
nb_elemwise
,
out_dtype
)
in
enumerate
([
((
dx
/
dy
)
/
dx
,
[
dx
,
dy
],
[
dxv
,
dyv
],
1
,
'float64'
),
((
dx
/
dy
)
/
dx
,
[
dx
,
dy
],
[
dxv
,
dyv
],
1
,
'float64'
),
((
fx
/
fy
)
/
fx
,
[
fx
,
fy
],
[
fxv
,
fyv
],
1
,
'float32'
),
((
fx
/
fy
)
/
fx
,
[
fx
,
fy
],
[
fxv
,
fyv
],
1
,
'float32'
),
((
dv
/
dy
)
/
dv
,
[
dv
,
dy
],
[
dvv
,
dyv
],
1
,
'float64'
),
((
dv
/
dy
)
/
dv
,
[
dv
,
dy
],
[
dvv
,
dyv
],
1
,
'float64'
),
((
fv
/
fy
)
/
fv
,
[
fv
,
fy
],
[
fvv
,
fyv
],
1
,
'float32'
),
((
fv
/
fy
)
/
fv
,
[
fv
,
fy
],
[
fvv
,
fyv
],
1
,
'float32'
),
# must broadcast as their is a dimshuffle in the computation
# must broadcast as their is a dimshuffle in the computation
((
dx
/
dv
)
/
dx
,
[
dx
,
dv
],
[
dxv
,
dvv
],
1
,
'float64'
),
((
dx
/
dv
)
/
dx
,
[
dx
,
dv
],
[
dxv
,
dvv
],
1
,
'float64'
),
# topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float64, row)>), Alloc]
# topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float64, row)>), Alloc]
((
fx
/
fv
)
/
fx
,
[
fx
,
fv
],
[
fxv
,
fvv
],
1
,
'float32'
),
((
fx
/
fv
)
/
fx
,
[
fx
,
fv
],
[
fxv
,
fvv
],
1
,
'float32'
),
# topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float32, row)>), Alloc]
# topo:[Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float32, row)>), Alloc]
]):
]):
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
mode
=
mode
)
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
mode
=
mode
)
out
=
f
(
*
val_inputs
)
out
=
f
(
*
val_inputs
)
utt
.
assert_allclose
(
out
,
(
1
/
val_inputs
[
1
]))
utt
.
assert_allclose
(
out
,
(
1
/
val_inputs
[
1
]))
topo
=
f
.
maker
.
fgraph
.
toposort
()
topo
=
f
.
maker
.
fgraph
.
toposort
()
...
@@ -639,69 +611,61 @@ class test_canonize(unittest.TestCase):
...
@@ -639,69 +611,61 @@ class test_canonize(unittest.TestCase):
# test (a / b) * (b / c) * (c / d) -> a / d
# test (a / b) * (b / c) * (c / d) -> a / d
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
enumerate
([
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
enumerate
([
((
dx
/
dy
)
*
(
dy
/
dz
)
*
(
dz
/
dw
),
[
dx
,
dy
,
dz
,
dw
],
[
dxv
,
dyv
,
dzv
,
dwv
],
'float64'
),
((
dx
/
dy
)
*
(
dy
/
dz
)
*
(
dz
/
dw
),
[
dx
,
dy
,
dz
,
dw
],
[
dxv
,
dyv
,
dzv
,
dwv
],
'float64'
),
((
fx
/
fy
)
*
(
fy
/
fz
)
*
(
fz
/
fw
),
[
fx
,
fy
,
fz
,
fw
],
[
fxv
,
fyv
,
fzv
,
fwv
],
'float32'
),
((
fx
/
fy
)
*
(
fy
/
fz
)
*
(
fz
/
fw
),
[
fx
,
fy
,
fz
,
fw
],
[
fxv
,
fyv
,
fzv
,
fwv
],
'float32'
),
((
dv
/
dy
)
*
(
dy
/
dz
)
*
(
dz
/
dw
),
[
dv
,
dy
,
dz
,
dw
],
[
dvv
,
dyv
,
dzv
,
dwv
],
'float64'
),
((
dv
/
dy
)
*
(
dy
/
dz
)
*
(
dz
/
dw
),
[
dv
,
dy
,
dz
,
dw
],
[
dvv
,
dyv
,
dzv
,
dwv
],
'float64'
),
((
fv
/
fy
)
*
(
fy
/
fz
)
*
(
fz
/
fw
),
[
fv
,
fy
,
fz
,
fw
],
[
fvv
,
fyv
,
fzv
,
fwv
],
'float32'
),
((
fv
/
fy
)
*
(
fy
/
fz
)
*
(
fz
/
fw
),
[
fv
,
fy
,
fz
,
fw
],
[
fvv
,
fyv
,
fzv
,
fwv
],
'float32'
),
((
dx
/
dv
)
*
(
dv
/
dz
)
*
(
dz
/
dw
),
[
dx
,
dv
,
dz
,
dw
],
[
dxv
,
dvv
,
dzv
,
dwv
],
'float64'
),
((
dx
/
dv
)
*
(
dv
/
dz
)
*
(
dz
/
dw
),
[
dx
,
dv
,
dz
,
dw
],
[
dxv
,
dvv
,
dzv
,
dwv
],
'float64'
),
((
fx
/
fv
)
*
(
fv
/
fz
)
*
(
fz
/
fw
),
[
fx
,
fv
,
fz
,
fw
],
[
fxv
,
fvv
,
fzv
,
fwv
],
'float32'
),
((
fx
/
fv
)
*
(
fv
/
fz
)
*
(
fz
/
fw
),
[
fx
,
fv
,
fz
,
fw
],
[
fxv
,
fvv
,
fzv
,
fwv
],
'float32'
),
((
dx
/
dy
)
*
(
dy
/
dv
)
*
(
dv
/
dw
),
[
dx
,
dy
,
dv
,
dw
],
[
dxv
,
dyv
,
dvv
,
dwv
],
'float64'
),
((
dx
/
dy
)
*
(
dy
/
dv
)
*
(
dv
/
dw
),
[
dx
,
dy
,
dv
,
dw
],
[
dxv
,
dyv
,
dvv
,
dwv
],
'float64'
),
((
fx
/
fy
)
*
(
fy
/
fv
)
*
(
fv
/
fw
),
[
fx
,
fy
,
fv
,
fw
],
[
fxv
,
fyv
,
fvv
,
fwv
],
'float32'
),
((
fx
/
fy
)
*
(
fy
/
fv
)
*
(
fv
/
fw
),
[
fx
,
fy
,
fv
,
fw
],
[
fxv
,
fyv
,
fvv
,
fwv
],
'float32'
),
((
dx
/
dy
)
*
(
dy
/
dz
)
*
(
dz
/
dv
),
[
dx
,
dy
,
dz
,
dv
],
[
dxv
,
dyv
,
dzv
,
dvv
],
'float64'
),
((
dx
/
dy
)
*
(
dy
/
dz
)
*
(
dz
/
dv
),
[
dx
,
dy
,
dz
,
dv
],
[
dxv
,
dyv
,
dzv
,
dvv
],
'float64'
),
((
fx
/
fy
)
*
(
fy
/
fz
)
*
(
fz
/
fv
),
[
fx
,
fy
,
fz
,
fv
],
[
fxv
,
fyv
,
fzv
,
fvv
],
'float32'
),
((
fx
/
fy
)
*
(
fy
/
fz
)
*
(
fz
/
fv
),
[
fx
,
fy
,
fz
,
fv
],
[
fxv
,
fyv
,
fzv
,
fvv
],
'float32'
),
]):
]):
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
mode
=
mode
)
mode
=
mode
)
out
=
f
(
*
val_inputs
)
out
=
f
(
*
val_inputs
)
utt
.
assert_allclose
(
out
,
(
val_inputs
[
0
]
/
val_inputs
[
3
]))
utt
.
assert_allclose
(
out
,
(
val_inputs
[
0
]
/
val_inputs
[
3
]))
topo
=
f
.
maker
.
fgraph
.
toposort
()
topo
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
topo
)
==
1
assert
len
(
topo
)
==
1
assert
isinstance
(
topo
[
0
]
.
op
,
(
T
.
Elemwise
,
))
assert
isinstance
(
topo
[
0
]
.
op
,
(
T
.
Elemwise
,
))
assert
isinstance
(
topo
[
0
]
.
op
.
scalar_op
,
assert
isinstance
(
topo
[
0
]
.
op
.
scalar_op
,
theano
.
scalar
.
basic
.
TrueDiv
)
theano
.
scalar
.
basic
.
TrueDiv
)
assert
len
(
topo
[
0
]
.
inputs
)
==
2
assert
len
(
topo
[
0
]
.
inputs
)
==
2
assert
(
out_dtype
==
out
.
dtype
)
assert
(
out_dtype
==
out
.
dtype
)
# test (2.0 * x) / (4.0 * y) -> (0.5 * x) / y
# test (2.0 * x) / (4.0 * y) -> (0.5 * x) / y
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
enumerate
([
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
enumerate
([
(((
2.0
*
dx
)
/
(
4.0
*
dy
)),
[
dx
,
dy
],
[
dxv
,
dyv
],
'float64'
),
(((
2.0
*
dx
)
/
(
4.0
*
dy
)),
[
dx
,
dy
],
[
dxv
,
dyv
],
'float64'
),
(((
2.0
*
fx
)
/
(
4.0
*
fy
)),
[
fx
,
fy
],
[
fxv
,
fyv
],
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
(((
2.0
*
fx
)
/
(
4.0
*
fy
)),
[
fx
,
fy
],
[
fxv
,
fyv
],
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
(((
2.0
*
dv
)
/
(
4.0
*
dy
)),
[
dv
,
dy
],
[
dvv
,
dyv
],
'float64'
),
(((
2.0
*
dv
)
/
(
4.0
*
dy
)),
[
dv
,
dy
],
[
dvv
,
dyv
],
'float64'
),
(((
2.0
*
fv
)
/
(
4.0
*
fy
)),
[
fv
,
fy
],
[
fvv
,
fyv
],
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
(((
2.0
*
fv
)
/
(
4.0
*
fy
)),
[
fv
,
fy
],
[
fvv
,
fyv
],
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
(((
2.0
*
dx
)
/
(
4.0
*
dv
)),
[
dx
,
dv
],
[
dxv
,
dvv
],
'float64'
),
(((
2.0
*
dx
)
/
(
4.0
*
dv
)),
[
dx
,
dv
],
[
dxv
,
dvv
],
'float64'
),
(((
2.0
*
fx
)
/
(
4.0
*
fv
)),
[
fx
,
fv
],
[
fxv
,
fvv
],
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
(((
2.0
*
fx
)
/
(
4.0
*
fv
)),
[
fx
,
fv
],
[
fxv
,
fvv
],
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
]):
]):
if
isinstance
(
out_dtype
,
dict
):
if
isinstance
(
out_dtype
,
dict
):
out_dtype
=
out_dtype
[
config
.
cast_policy
]
out_dtype
=
out_dtype
[
config
.
cast_policy
]
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
mode
=
mode
)
mode
=
mode
)
out
=
f
(
*
val_inputs
)
out
=
f
(
*
val_inputs
)
utt
.
assert_allclose
(
out
,
(
0.5
*
utt
.
assert_allclose
(
out
,
(
0.5
*
val_inputs
[
0
]
/
val_inputs
[
1
]))
val_inputs
[
0
]
/
val_inputs
[
1
]))
topo
=
f
.
maker
.
fgraph
.
toposort
()
topo
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
topo
)
==
2
assert
len
(
topo
)
==
2
assert
isinstance
(
topo
[
0
]
.
op
,
(
T
.
Elemwise
,
))
assert
isinstance
(
topo
[
0
]
.
op
,
(
T
.
Elemwise
,
))
assert
isinstance
(
topo
[
0
]
.
op
.
scalar_op
,
assert
isinstance
(
topo
[
0
]
.
op
.
scalar_op
,
theano
.
scalar
.
basic
.
Mul
)
theano
.
scalar
.
basic
.
Mul
)
assert
len
(
topo
[
0
]
.
inputs
)
==
2
assert
len
(
topo
[
0
]
.
inputs
)
==
2
assert
isinstance
(
topo
[
1
]
.
op
,
(
T
.
Elemwise
,
))
assert
isinstance
(
topo
[
1
]
.
op
,
(
T
.
Elemwise
,
))
assert
isinstance
(
topo
[
1
]
.
op
.
scalar_op
,
assert
isinstance
(
topo
[
1
]
.
op
.
scalar_op
,
theano
.
scalar
.
basic
.
TrueDiv
)
theano
.
scalar
.
basic
.
TrueDiv
)
assert
len
(
topo
[
1
]
.
inputs
)
==
2
assert
len
(
topo
[
1
]
.
inputs
)
==
2
assert
(
out_dtype
==
out
.
dtype
)
assert
(
out_dtype
==
out
.
dtype
)
# test 2 * x / 2 -> x
# test 2 * x / 2 -> x
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
enumerate
([
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
enumerate
([
((
2
*
dx
)
/
2
,
[
dx
],
[
dxv
],
'float64'
),
((
2
*
dx
)
/
2
,
[
dx
],
[
dxv
],
'float64'
),
((
2
*
fx
)
/
2
,
[
fx
],
[
fxv
],
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
((
2
*
fx
)
/
2
,
[
fx
],
[
fxv
],
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
((
2
*
dv
)
/
2
,
[
dv
],
[
dvv
],
'float64'
),
((
2
*
dv
)
/
2
,
[
dv
],
[
dvv
],
'float64'
),
((
2
*
fv
)
/
2
,
[
fv
],
[
fvv
],
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
((
2
*
fv
)
/
2
,
[
fv
],
[
fvv
],
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
]):
]):
if
isinstance
(
out_dtype
,
dict
):
if
isinstance
(
out_dtype
,
dict
):
out_dtype
=
out_dtype
[
config
.
cast_policy
]
out_dtype
=
out_dtype
[
config
.
cast_policy
]
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
mode
=
mode
)
mode
=
mode
)
out
=
f
(
*
val_inputs
)
out
=
f
(
*
val_inputs
)
utt
.
assert_allclose
(
out
,
val_inputs
[
0
])
utt
.
assert_allclose
(
out
,
val_inputs
[
0
])
topo
=
f
.
maker
.
fgraph
.
toposort
()
topo
=
f
.
maker
.
fgraph
.
toposort
()
...
@@ -711,15 +675,14 @@ class test_canonize(unittest.TestCase):
...
@@ -711,15 +675,14 @@ class test_canonize(unittest.TestCase):
# test x / abs(x) -> sign(x)
# test x / abs(x) -> sign(x)
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
enumerate
([
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
enumerate
([
(
dx
/
abs
(
dx
),
[
dx
],
[
0.5
-
dxv
],
'float64'
),
(
dx
/
abs
(
dx
),
[
dx
],
[
0.5
-
dxv
],
'float64'
),
(
fx
/
abs
(
fx
),
[
fx
],
[
0.5
-
fxv
],
'float32'
),
(
fx
/
abs
(
fx
),
[
fx
],
[
0.5
-
fxv
],
'float32'
),
(
dx
/
abs
(
dx
),
[
dx
],
[
0.1
*
dxv
],
'float64'
),
(
dx
/
abs
(
dx
),
[
dx
],
[
0.1
*
dxv
],
'float64'
),
(
fx
/
abs
(
fx
),
[
fx
],
[
0.1
*
fxv
],
'float32'
),
(
fx
/
abs
(
fx
),
[
fx
],
[
0.1
*
fxv
],
'float32'
),
(
dv
/
abs
(
dv
),
[
dv
],
[
0.5
-
dvv
],
'float64'
),
(
dv
/
abs
(
dv
),
[
dv
],
[
0.5
-
dvv
],
'float64'
),
(
fv
/
abs
(
fv
),
[
fv
],
[
0.5
-
fvv
],
'float32'
),
(
fv
/
abs
(
fv
),
[
fv
],
[
0.5
-
fvv
],
'float32'
),
]):
]):
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
mode
=
mode
)
mode
=
mode
)
out
=
f
(
*
val_inputs
)
out
=
f
(
*
val_inputs
)
assert
numpy
.
all
(
numpy
.
isfinite
(
out
))
assert
numpy
.
all
(
numpy
.
isfinite
(
out
))
utt
.
assert_allclose
(
out
,
numpy
.
sign
(
val_inputs
[
0
]))
utt
.
assert_allclose
(
out
,
numpy
.
sign
(
val_inputs
[
0
]))
...
@@ -730,14 +693,14 @@ class test_canonize(unittest.TestCase):
...
@@ -730,14 +693,14 @@ class test_canonize(unittest.TestCase):
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
enumerate
([
for
id
,
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
enumerate
([
((
2
*
dx
)
/
(
3
*
abs
(
dx
)),
[
dx
],
[
0.5
-
dxv
],
'float64'
),
((
2
*
dx
)
/
(
3
*
abs
(
dx
)),
[
dx
],
[
0.5
-
dxv
],
'float64'
),
((
2
*
fx
)
/
(
3
*
abs
(
fx
)),
[
fx
],
[
0.5
-
fxv
],
((
2
*
fx
)
/
(
3
*
abs
(
fx
)),
[
fx
],
[
0.5
-
fxv
],
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
((
2
*
dx
)
/
(
3
*
abs
(
dx
)),
[
dx
],
[
0.1
*
dxv
],
'float64'
),
((
2
*
dx
)
/
(
3
*
abs
(
dx
)),
[
dx
],
[
0.1
*
dxv
],
'float64'
),
((
2
*
fx
)
/
(
3
*
abs
(
fx
)),
[
fx
],
[
0.1
*
fxv
],
((
2
*
fx
)
/
(
3
*
abs
(
fx
)),
[
fx
],
[
0.1
*
fxv
],
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
((
2
*
dv
)
/
(
3
*
abs
(
dv
)),
[
dv
],
[
0.5
-
dvv
],
'float64'
),
((
2
*
dv
)
/
(
3
*
abs
(
dv
)),
[
dv
],
[
0.5
-
dvv
],
'float64'
),
((
2
*
fv
)
/
(
3
*
abs
(
fv
)),
[
fv
],
[
0.5
-
fvv
],
((
2
*
fv
)
/
(
3
*
abs
(
fv
)),
[
fv
],
[
0.5
-
fvv
],
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
{
'custom'
:
'float32'
,
'numpy+floatX'
:
config
.
floatX
,
'numpy'
:
'float64'
}),
]):
]):
if
isinstance
(
out_dtype
,
dict
):
if
isinstance
(
out_dtype
,
dict
):
out_dtype
=
out_dtype
[
config
.
cast_policy
]
out_dtype
=
out_dtype
[
config
.
cast_policy
]
...
@@ -756,14 +719,14 @@ class test_canonize(unittest.TestCase):
...
@@ -756,14 +719,14 @@ class test_canonize(unittest.TestCase):
"""
"""
x
=
T
.
dscalar
()
x
=
T
.
dscalar
()
a
=
T
.
abs_
(
x
)
#
a = T.abs_(x)
if
theano
.
config
.
mode
==
'FAST_COMPILE'
:
if
theano
.
config
.
mode
==
'FAST_COMPILE'
:
mode
=
theano
.
compile
.
mode
.
get_mode
(
'FAST_RUN'
)
.
excluding
(
mode
=
theano
.
compile
.
mode
.
get_mode
(
'FAST_RUN'
)
.
excluding
(
"local_elemwise_fusion"
)
"local_elemwise_fusion"
)
else
:
else
:
mode
=
theano
.
compile
.
mode
.
get_default_mode
()
.
excluding
(
mode
=
theano
.
compile
.
mode
.
get_default_mode
()
.
excluding
(
"local_elemwise_fusion"
)
"local_elemwise_fusion"
)
f
=
theano
.
function
([
x
],
[(
4
*
x
)
/
abs
(
2
*
x
)],
mode
=
mode
)
f
=
theano
.
function
([
x
],
[(
4
*
x
)
/
abs
(
2
*
x
)],
mode
=
mode
)
print
(
f
.
maker
.
fgraph
.
toposort
())
print
(
f
.
maker
.
fgraph
.
toposort
())
...
@@ -804,49 +767,43 @@ class test_canonize(unittest.TestCase):
...
@@ -804,49 +767,43 @@ class test_canonize(unittest.TestCase):
dxv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shp
),
dtype
=
'float32'
)
dxv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shp
),
dtype
=
'float32'
)
dyv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shp
),
dtype
=
'float32'
)
dyv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shp
),
dtype
=
'float32'
)
dzv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shp
),
dtype
=
'float32'
)
dzv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
shp
),
dtype
=
'float32'
)
fvv
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
shp
[
0
]),
dtype
=
'float32'
)
.
reshape
(
1
,
shp
[
0
])
#
fvv = theano._asarray(numpy.random.rand(shp[0]), dtype='float32').reshape(1, shp[0])
# We must be sure that the Canonizer is working, but that we don't have other
# We must be sure that the Canonizer is working, but that we don't have other
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion
mode
=
compile
.
mode
.
get_default_mode
()
mode
=
compile
.
mode
.
get_default_mode
()
opt
=
gof
.
Query
([
"canonicalize"
])
opt
=
gof
.
Query
([
"canonicalize"
])
opt
=
opt
.
excluding
(
opt
=
opt
.
excluding
(
'local_elemwise_fusion'
)
'local_elemwise_fusion'
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
# test fail!
# test fail!
# test x / y / z -> x / (y * z)
# test x / y / z -> x / (y * z)
for
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
[
for
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
[
((
dx
/
dy
)
/
dz
,
[
dx
,
dy
,
dz
],
[
dxv
,
dyv
,
dzv
],
'float64'
),
((
dx
/
dy
)
/
dz
,
[
dx
,
dy
,
dz
],
[
dxv
,
dyv
,
dzv
],
'float64'
),
((
fx
/
fy
)
/
fz
,
[
fx
,
fy
,
fz
],
[
fxv
,
fyv
,
fzv
],
'float32'
)
((
fx
/
fy
)
/
fz
,
[
fx
,
fy
,
fz
],
[
fxv
,
fyv
,
fzv
],
'float32'
)
]:
]:
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
mode
=
mode
)
mode
=
mode
)
out
=
f
(
*
val_inputs
)
out
=
f
(
*
val_inputs
)
utt
.
assert_allclose
(
out
,
val_inputs
[
0
]
/
utt
.
assert_allclose
(
out
,
val_inputs
[
0
]
/
val_inputs
[
1
]
/
val_inputs
[
2
])
val_inputs
[
1
]
/
val_inputs
[
2
])
topo
=
f
.
maker
.
fgraph
.
toposort
()
topo
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
topo
)
==
2
assert
len
(
topo
)
==
2
assert
isinstance
(
topo
[
0
]
.
op
,
(
T
.
Elemwise
,
))
assert
isinstance
(
topo
[
0
]
.
op
,
(
T
.
Elemwise
,
))
assert
isinstance
(
topo
[
0
]
.
op
.
scalar_op
,
assert
isinstance
(
topo
[
0
]
.
op
.
scalar_op
,
theano
.
scalar
.
basic
.
Inv
)
theano
.
scalar
.
basic
.
Inv
)
assert
len
(
topo
[
0
]
.
inputs
)
==
1
assert
len
(
topo
[
0
]
.
inputs
)
==
1
assert
(
out_dtype
==
out
.
dtype
)
assert
(
out_dtype
==
out
.
dtype
)
# test x / (y / z) -> (x * z) / y
# test x / (y / z) -> (x * z) / y
for
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
[
for
(
g
,
sym_inputs
,
val_inputs
,
out_dtype
)
in
[
(
dx
/
(
dy
/
dz
),
[
dx
,
dy
,
dz
],
[
dxv
,
dyv
,
dzv
],
'float64'
),
(
dx
/
(
dy
/
dz
),
[
dx
,
dy
,
dz
],
[
dxv
,
dyv
,
dzv
],
'float64'
),
(
fx
/
(
fy
/
fz
),
[
fx
,
fy
,
fz
],
[
fxv
,
fyv
,
fzv
],
'float32'
)
(
fx
/
(
fy
/
fz
),
[
fx
,
fy
,
fz
],
[
fxv
,
fyv
,
fzv
],
'float32'
)
]:
]:
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
f
=
compile
.
function
(
list
(
sym_inputs
),
g
,
mode
=
mode
)
mode
=
mode
)
out
=
f
(
*
val_inputs
)
out
=
f
(
*
val_inputs
)
utt
.
assert_allclose
(
out
,
val_inputs
[
0
]
/
(
utt
.
assert_allclose
(
out
,
val_inputs
[
0
]
/
(
val_inputs
[
1
]
/
val_inputs
[
2
]))
val_inputs
[
1
]
/
val_inputs
[
2
]))
topo
=
f
.
maker
.
fgraph
.
toposort
()
topo
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
topo
)
==
2
assert
len
(
topo
)
==
2
assert
isinstance
(
topo
[
0
]
.
op
,
(
T
.
Elemwise
,
))
assert
isinstance
(
topo
[
0
]
.
op
,
(
T
.
Elemwise
,
))
assert
isinstance
(
topo
[
0
]
.
op
.
scalar_op
,
assert
isinstance
(
topo
[
0
]
.
op
.
scalar_op
,
theano
.
scalar
.
basic
.
Inv
)
theano
.
scalar
.
basic
.
Inv
)
assert
len
(
topo
[
0
]
.
inputs
)
==
1
assert
len
(
topo
[
0
]
.
inputs
)
==
1
assert
(
out_dtype
==
out
.
dtype
)
assert
(
out_dtype
==
out
.
dtype
)
...
@@ -868,7 +825,7 @@ class test_canonize(unittest.TestCase):
...
@@ -868,7 +825,7 @@ class test_canonize(unittest.TestCase):
logging
.
getLogger
(
'theano.gof.opt'
)
.
addHandler
(
handler
)
logging
.
getLogger
(
'theano.gof.opt'
)
.
addHandler
(
handler
)
try
:
try
:
x
=
vector
()
x
=
vector
()
f
=
theano
.
function
([
x
],
x
+
numpy
.
nan
)
theano
.
function
([
x
],
x
+
numpy
.
nan
)
finally
:
finally
:
logging
.
getLogger
(
'theano.gof.opt'
)
.
removeHandler
(
handler
)
logging
.
getLogger
(
'theano.gof.opt'
)
.
removeHandler
(
handler
)
# Ideally this test would only catch the maxed out equilibrium
# Ideally this test would only catch the maxed out equilibrium
...
@@ -887,7 +844,7 @@ def test_local_merge_abs():
...
@@ -887,7 +844,7 @@ def test_local_merge_abs():
if
mode
==
"FAST_COMPILE"
:
if
mode
==
"FAST_COMPILE"
:
mode
=
"FAST_RUN"
mode
=
"FAST_RUN"
mode
=
theano
.
compile
.
mode
.
get_mode
(
mode
)
.
excluding
(
mode
=
theano
.
compile
.
mode
.
get_mode
(
mode
)
.
excluding
(
"local_elemwise_fusion"
)
"local_elemwise_fusion"
)
f
=
theano
.
function
([
y
,
z
],
(
abs
(
y
*
z
*
-
2
)),
mode
=
mode
)
f
=
theano
.
function
([
y
,
z
],
(
abs
(
y
*
z
*
-
2
)),
mode
=
mode
)
f
(
y_val
,
z_val
)
f
(
y_val
,
z_val
)
...
@@ -960,7 +917,6 @@ class test_fusion(unittest.TestCase):
...
@@ -960,7 +917,6 @@ class test_fusion(unittest.TestCase):
"""
"""
# TODO: disable the canonizer?
# TODO: disable the canonizer?
def
my_init
(
shp
,
dtype
=
'float64'
,
num
=
0
):
def
my_init
(
shp
,
dtype
=
'float64'
,
num
=
0
):
#ret = theano._asarray(numpy.random.rand(*shp),dtype=dtype)
ret
=
numpy
.
zeros
(
shp
,
dtype
=
dtype
)
+
num
ret
=
numpy
.
zeros
(
shp
,
dtype
=
dtype
)
+
num
return
ret
return
ret
fw
,
fx
,
fy
,
fz
=
[
theano
.
tensor
.
tensor
(
dtype
=
'float32'
,
fw
,
fx
,
fy
,
fz
=
[
theano
.
tensor
.
tensor
(
dtype
=
'float32'
,
...
@@ -997,7 +953,7 @@ class test_fusion(unittest.TestCase):
...
@@ -997,7 +953,7 @@ class test_fusion(unittest.TestCase):
(
fx
*
fy
+
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
*
(
fx
*
fy
+
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
*
fyv
+
fzv
,
'float32'
),
# 3
fyv
+
fzv
,
'float32'
),
# 3
(
fw
+
fx
+
fy
+
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
(
fw
+
fx
+
fy
+
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
'float32'
),
fwv
+
fxv
+
fyv
+
fzv
,
'float32'
),
((
fw
+
fx
)
+
(
fy
+
fz
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
((
fw
+
fx
)
+
(
fy
+
fz
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
'float32'
),
# 5
fwv
+
fxv
+
fyv
+
fzv
,
'float32'
),
# 5
(((
fw
+
fx
)
+
fy
)
+
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
(((
fw
+
fx
)
+
fy
)
+
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论