Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d40861ec
提交
d40861ec
authored
6月 26, 2015
作者:
Iban Harlouchet
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Flake8 for theano/tensor/opt.py
上级
34b98041
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
280 行增加
和
291 行删除
+280
-291
opt.py
theano/tensor/opt.py
+280
-290
test_flake8.py
theano/tests/test_flake8.py
+0
-1
没有找到文件。
theano/tensor/opt.py
浏览文件 @
d40861ec
...
@@ -6,8 +6,6 @@ from __future__ import print_function
...
@@ -6,8 +6,6 @@ from __future__ import print_function
# TODO: 0*x -> 0
# TODO: 0*x -> 0
import
logging
import
logging
_logger
=
logging
.
getLogger
(
'theano.tensor.opt'
)
import
itertools
import
itertools
import
operator
import
operator
import
sys
import
sys
...
@@ -34,12 +32,10 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
...
@@ -34,12 +32,10 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
Subtensor
,
IncSubtensor
,
make_constant
,
Subtensor
,
IncSubtensor
,
make_constant
,
AdvancedIncSubtensor1
,
AdvancedIncSubtensor1
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor
,
AdvancedSubtensor
,
AdvancedSubtensor1
,
AdvancedSubtensor1
,
advanced_subtensor
,
advanced_subtensor
,
advanced_subtensor1
,
advanced_subtensor1
,
advanced_inc_subtensor1
,
advanced_inc_subtensor1
)
inc_subtensor
)
from
theano
import
scalar
from
theano
import
scalar
from
theano.scalar
import
basic
from
theano.scalar
import
basic
from
theano.tensor
import
basic
as
T
from
theano.tensor
import
basic
as
T
...
@@ -56,6 +52,8 @@ from theano.gof import toolbox
...
@@ -56,6 +52,8 @@ from theano.gof import toolbox
from
theano.tensor.basic
import
get_scalar_constant_value
,
ShapeError
,
NotScalarConstantError
from
theano.tensor.basic
import
get_scalar_constant_value
,
ShapeError
,
NotScalarConstantError
from
six
import
StringIO
from
six
import
StringIO
_logger
=
logging
.
getLogger
(
'theano.tensor.opt'
)
theano
.
configparser
.
AddConfigVar
(
'on_shape_error'
,
theano
.
configparser
.
AddConfigVar
(
'on_shape_error'
,
"warn: print a warning and use the default"
"warn: print a warning and use the default"
" value. raise: raise an error"
,
" value. raise: raise an error"
,
...
@@ -165,23 +163,24 @@ def broadcast_like(value, template, fgraph, dtype=None):
...
@@ -165,23 +163,24 @@ def broadcast_like(value, template, fgraph, dtype=None):
# the template may have 1s in its shape without being broadcastable
# the template may have 1s in its shape without being broadcastable
if
rval
.
broadcastable
!=
template
.
broadcastable
:
if
rval
.
broadcastable
!=
template
.
broadcastable
:
rval
=
T
.
unbroadcast
(
rval
,
*
[
i
for
i
in
xrange
(
rval
.
ndim
)
rval
=
T
.
unbroadcast
(
rval
,
*
[
i
for
i
in
xrange
(
rval
.
ndim
)
if
rval
.
broadcastable
[
i
]
if
rval
.
broadcastable
[
i
]
and
and
not
template
.
broadcastable
[
i
]])
not
template
.
broadcastable
[
i
]])
assert
rval
.
type
.
dtype
==
dtype
assert
rval
.
type
.
dtype
==
dtype
if
rval
.
type
.
broadcastable
!=
template
.
broadcastable
:
if
rval
.
type
.
broadcastable
!=
template
.
broadcastable
:
raise
AssertionError
(
"rval.type.broadcastable is "
+
raise
AssertionError
(
"rval.type.broadcastable is "
+
str
(
rval
.
type
.
broadcastable
)
+
str
(
rval
.
type
.
broadcastable
)
+
" but template.broadcastable is"
+
" but template.broadcastable is"
+
str
(
template
.
broadcastable
))
str
(
template
.
broadcastable
))
return
rval
return
rval
theano
.
configparser
.
AddConfigVar
(
'tensor.insert_inplace_optimizer_validate_nb'
,
theano
.
configparser
.
AddConfigVar
(
"-1: auto, if graph have less then 500 nodes 1, else 10"
,
'tensor.insert_inplace_optimizer_validate_nb'
,
theano
.
configparser
.
IntParam
(
-
1
),
"-1: auto, if graph have less then 500 nodes 1, else 10"
,
in_c_key
=
False
)
theano
.
configparser
.
IntParam
(
-
1
),
in_c_key
=
False
)
def
inplace_elemwise_optimizer_op
(
OP
):
def
inplace_elemwise_optimizer_op
(
OP
):
...
@@ -251,11 +250,10 @@ def inplace_elemwise_optimizer_op(OP):
...
@@ -251,11 +250,10 @@ def inplace_elemwise_optimizer_op(OP):
# target.
# target.
# Remove here as faster.
# Remove here as faster.
candidate_inputs
=
[
i
for
i
in
xrange
(
len
(
node
.
inputs
))
candidate_inputs
=
[
i
for
i
in
xrange
(
len
(
node
.
inputs
))
if
i
not
in
baseline
.
values
()
\
if
i
not
in
baseline
.
values
()
and
and
not
isinstance
(
node
.
inputs
[
i
],
not
isinstance
(
node
.
inputs
[
i
],
Constant
)
and
Constant
)
\
not
fgraph
.
destroyers
(
node
.
inputs
[
i
])
and
and
not
fgraph
.
destroyers
(
node
.
inputs
[
i
])
\
node
.
inputs
[
i
]
not
in
protected_inputs
]
and
node
.
inputs
[
i
]
not
in
protected_inputs
]
verbose
=
False
verbose
=
False
...
@@ -265,7 +263,7 @@ def inplace_elemwise_optimizer_op(OP):
...
@@ -265,7 +263,7 @@ def inplace_elemwise_optimizer_op(OP):
for
candidate_input
in
candidate_inputs
:
for
candidate_input
in
candidate_inputs
:
# remove inputs that don't have the same dtype as the output
# remove inputs that don't have the same dtype as the output
if
node
.
inputs
[
candidate_input
]
.
type
!=
node
.
outputs
[
if
node
.
inputs
[
candidate_input
]
.
type
!=
node
.
outputs
[
candidate_output
]
.
type
:
candidate_output
]
.
type
:
continue
continue
inplace_pattern
=
dict
(
baseline
)
inplace_pattern
=
dict
(
baseline
)
...
@@ -274,20 +272,20 @@ def inplace_elemwise_optimizer_op(OP):
...
@@ -274,20 +272,20 @@ def inplace_elemwise_optimizer_op(OP):
if
hasattr
(
op
.
scalar_op
,
"make_new_inplace"
):
if
hasattr
(
op
.
scalar_op
,
"make_new_inplace"
):
new_scal
=
op
.
scalar_op
.
make_new_inplace
(
new_scal
=
op
.
scalar_op
.
make_new_inplace
(
scalar
.
transfer_type
(
scalar
.
transfer_type
(
*
[
inplace_pattern
.
get
(
i
,
None
)
\
*
[
inplace_pattern
.
get
(
i
,
None
)
for
i
in
xrange
(
len
(
node
.
outputs
))]))
for
i
in
xrange
(
len
(
node
.
outputs
))]))
else
:
else
:
new_scal
=
op
.
scalar_op
.
__class__
(
new_scal
=
op
.
scalar_op
.
__class__
(
scalar
.
transfer_type
(
scalar
.
transfer_type
(
*
[
inplace_pattern
.
get
(
i
,
None
)
\
*
[
inplace_pattern
.
get
(
i
,
None
)
for
i
in
xrange
(
len
(
node
.
outputs
))]))
for
i
in
xrange
(
len
(
node
.
outputs
))]))
new_outputs
=
OP
(
new_scal
,
inplace_pattern
)(
new_outputs
=
OP
(
new_scal
,
inplace_pattern
)(
*
node
.
inputs
,
**
dict
(
return_list
=
True
))
*
node
.
inputs
,
**
dict
(
return_list
=
True
))
new_node
=
new_outputs
[
0
]
.
owner
new_node
=
new_outputs
[
0
]
.
owner
for
r
,
new_r
in
zip
(
node
.
outputs
,
new_outputs
):
for
r
,
new_r
in
zip
(
node
.
outputs
,
new_outputs
):
fgraph
.
replace
(
r
,
new_r
,
fgraph
.
replace
(
r
,
new_r
,
reason
=
"inplace_elemwise_optimizer"
)
reason
=
"inplace_elemwise_optimizer"
)
nb_change_no_validate
+=
1
nb_change_no_validate
+=
1
if
nb_change_no_validate
>=
check_each_change
:
if
nb_change_no_validate
>=
check_each_change
:
fgraph
.
validate
()
fgraph
.
validate
()
...
@@ -295,9 +293,9 @@ def inplace_elemwise_optimizer_op(OP):
...
@@ -295,9 +293,9 @@ def inplace_elemwise_optimizer_op(OP):
nb_change_no_validate
=
0
nb_change_no_validate
=
0
except
(
ValueError
,
TypeError
,
InconsistencyError
)
as
e
:
except
(
ValueError
,
TypeError
,
InconsistencyError
)
as
e
:
if
check_each_change
!=
1
and
not
raised_warning
:
if
check_each_change
!=
1
and
not
raised_warning
:
print
((
print
((
"Some inplace optimization was not "
"Some inplace optimization was not "
"performed due to unexpected error:"
),
"performed due to unexpected error:"
),
file
=
sys
.
stderr
)
file
=
sys
.
stderr
)
print
(
e
,
file
=
sys
.
stderr
)
print
(
e
,
file
=
sys
.
stderr
)
raised_warning
=
True
raised_warning
=
True
fgraph
.
revert
(
chk
)
fgraph
.
revert
(
chk
)
...
@@ -313,7 +311,8 @@ def inplace_elemwise_optimizer_op(OP):
...
@@ -313,7 +311,8 @@ def inplace_elemwise_optimizer_op(OP):
except
Exception
:
except
Exception
:
if
not
raised_warning
:
if
not
raised_warning
:
print
((
"Some inplace optimization was not "
print
((
"Some inplace optimization was not "
"performed due to unexpected error"
),
file
=
sys
.
stderr
)
"performed due to unexpected error"
),
file
=
sys
.
stderr
)
fgraph
.
revert
(
chk
)
fgraph
.
revert
(
chk
)
return
inplace_elemwise_optimizer
return
inplace_elemwise_optimizer
...
@@ -381,8 +380,8 @@ def register_specialize_device(lopt, *tags, **kwargs):
...
@@ -381,8 +380,8 @@ def register_specialize_device(lopt, *tags, **kwargs):
# Register merge_optimizer as a global opt during canonicalize
# Register merge_optimizer as a global opt during canonicalize
compile
.
optdb
[
'canonicalize'
]
.
register
(
compile
.
optdb
[
'canonicalize'
]
.
register
(
'canon_merge'
,
merge_optimizer
,
'canon_merge'
,
merge_optimizer
,
'fast_run'
,
final_opt
=
True
)
'fast_run'
,
final_opt
=
True
)
#####################
#####################
...
@@ -512,11 +511,10 @@ def local_lift_transpose_through_dot(node):
...
@@ -512,11 +511,10 @@ def local_lift_transpose_through_dot(node):
inplace. The newly-introduced transpositions are not inplace, this will
inplace. The newly-introduced transpositions are not inplace, this will
be taken care of in a later optimization phase.
be taken care of in a later optimization phase.
"""
"""
if
not
(
isinstance
(
node
.
op
,
T
.
DimShuffle
)
if
not
(
isinstance
(
node
.
op
,
T
.
DimShuffle
)
and
node
.
op
.
new_order
==
(
1
,
0
)):
and
node
.
op
.
new_order
==
(
1
,
0
)):
return
False
return
False
if
not
(
node
.
inputs
[
0
]
.
owner
if
not
(
node
.
inputs
[
0
]
.
owner
and
and
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
T
.
Dot
)):
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
T
.
Dot
)):
return
False
return
False
x
,
y
=
node
.
inputs
[
0
]
.
owner
.
inputs
x
,
y
=
node
.
inputs
[
0
]
.
owner
.
inputs
...
@@ -601,22 +599,19 @@ class MakeVector(T.Op):
...
@@ -601,22 +599,19 @@ class MakeVector(T.Op):
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
inputs
=
list
(
map
(
T
.
as_tensor_variable
,
inputs
))
inputs
=
list
(
map
(
T
.
as_tensor_variable
,
inputs
))
if
not
all
(
a
.
type
==
inputs
[
0
]
.
type
for
a
in
inputs
)
or
(
if
(
not
all
(
a
.
type
==
inputs
[
0
]
.
type
for
a
in
inputs
)
or
len
(
inputs
)
>
0
and
inputs
[
0
]
.
dtype
!=
self
.
dtype
):
(
len
(
inputs
)
>
0
and
inputs
[
0
]
.
dtype
!=
self
.
dtype
)):
dtype
=
theano
.
scalar
.
upcast
(
self
.
dtype
,
dtype
=
theano
.
scalar
.
upcast
(
self
.
dtype
,
*
[
i
.
dtype
for
i
in
inputs
])
*
[
i
.
dtype
for
i
in
inputs
])
# upcast the input to the determined dtype,
# upcast the input to the determined dtype,
# but don't downcast anything
# but don't downcast anything
assert
dtype
==
self
.
dtype
,
(
assert
dtype
==
self
.
dtype
,
(
"The upcast of the inputs to MakeVector should match the "
"The upcast of the inputs to MakeVector should match the "
"dtype given in __init__."
)
"dtype given in __init__."
)
if
not
all
(
self
.
dtype
==
T
.
cast
(
i
,
dtype
=
dtype
)
.
dtype
if
not
all
(
self
.
dtype
==
T
.
cast
(
i
,
dtype
=
dtype
)
.
dtype
for
i
in
inputs
):
for
i
in
inputs
):
raise
TypeError
(
"MakeVector.make_node expected inputs"
raise
TypeError
(
"MakeVector.make_node expected inputs"
" upcastable to
%
s. got
%
s"
%
(
" upcastable to
%
s. got
%
s"
%
self
.
dtype
,
(
self
.
dtype
,
str
([
i
.
dtype
for
i
in
inputs
])))
str
([
i
.
dtype
for
i
in
inputs
])
))
inputs
=
[
T
.
cast
(
i
,
dtype
=
dtype
)
for
i
in
inputs
]
inputs
=
[
T
.
cast
(
i
,
dtype
=
dtype
)
for
i
in
inputs
]
assert
all
(
self
.
dtype
==
a
.
dtype
for
a
in
inputs
)
assert
all
(
self
.
dtype
==
a
.
dtype
for
a
in
inputs
)
assert
all
(
a
.
ndim
==
0
for
a
in
inputs
)
assert
all
(
a
.
ndim
==
0
for
a
in
inputs
)
...
@@ -625,11 +620,9 @@ class MakeVector(T.Op):
...
@@ -625,11 +620,9 @@ class MakeVector(T.Op):
dtype
=
inputs
[
0
]
.
type
.
dtype
dtype
=
inputs
[
0
]
.
type
.
dtype
else
:
else
:
dtype
=
self
.
dtype
dtype
=
self
.
dtype
#bcastable = (len(inputs) == 1)
#
bcastable = (len(inputs) == 1)
bcastable
=
False
bcastable
=
False
otype
=
T
.
TensorType
(
otype
=
T
.
TensorType
(
broadcastable
=
(
bcastable
,),
dtype
=
dtype
)
broadcastable
=
(
bcastable
,),
dtype
=
dtype
)
return
T
.
Apply
(
self
,
inputs
,
[
otype
()])
return
T
.
Apply
(
self
,
inputs
,
[
otype
()])
def
__str__
(
self
):
def
__str__
(
self
):
...
@@ -700,13 +693,14 @@ class MakeVectorPrinter:
...
@@ -700,13 +693,14 @@ class MakeVectorPrinter:
if
r
.
owner
is
None
:
if
r
.
owner
is
None
:
raise
TypeError
(
"Can only print make_vector."
)
raise
TypeError
(
"Can only print make_vector."
)
elif
isinstance
(
r
.
owner
.
op
,
MakeVector
):
elif
isinstance
(
r
.
owner
.
op
,
MakeVector
):
return
"[
%
s]"
%
", "
.
join
(
pstate
.
pprinter
.
process
(
return
"[
%
s]"
%
", "
.
join
(
input
,
pstate
.
clone
(
precedence
=
1000
))
for
input
pstate
.
pprinter
.
process
(
input
,
pstate
.
clone
(
precedence
=
1000
))
in
r
.
owner
.
inputs
)
for
input
in
r
.
owner
.
inputs
)
else
:
else
:
raise
TypeError
(
"Can only print make_vector."
)
raise
TypeError
(
"Can only print make_vector."
)
T
.
pprint
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
MakeVector
),
MakeVectorPrinter
())
T
.
pprint
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
MakeVector
),
MakeVectorPrinter
())
class
ShapeFeature
(
object
):
class
ShapeFeature
(
object
):
...
@@ -843,8 +837,8 @@ class ShapeFeature(object):
...
@@ -843,8 +837,8 @@ class ShapeFeature(object):
# by always returning the same object to represent 1
# by always returning the same object to represent 1
return
self
.
lscalar_one
return
self
.
lscalar_one
if
(
type
(
s_i
)
in
integer_types
or
if
(
type
(
s_i
)
in
integer_types
or
isinstance
(
s_i
,
numpy
.
integer
)
or
isinstance
(
s_i
,
numpy
.
integer
)
or
(
isinstance
(
s_i
,
numpy
.
ndarray
)
and
s_i
.
ndim
==
0
)):
(
isinstance
(
s_i
,
numpy
.
ndarray
)
and
s_i
.
ndim
==
0
)):
# this shape is a constant
# this shape is a constant
assert
s_i
>=
0
assert
s_i
>=
0
return
T
.
constant
(
s_i
,
dtype
=
'int64'
)
return
T
.
constant
(
s_i
,
dtype
=
'int64'
)
...
@@ -859,9 +853,9 @@ class ShapeFeature(object):
...
@@ -859,9 +853,9 @@ class ShapeFeature(object):
# s_i is x.shape[i], we change it to Shape_i.
# s_i is x.shape[i], we change it to Shape_i.
if
(
s_i
.
owner
and
if
(
s_i
.
owner
and
isinstance
(
s_i
.
owner
.
op
,
Subtensor
)
and
isinstance
(
s_i
.
owner
.
op
,
Subtensor
)
and
s_i
.
owner
.
inputs
[
0
]
.
owner
and
s_i
.
owner
.
inputs
[
0
]
.
owner
and
isinstance
(
s_i
.
owner
.
inputs
[
0
]
.
owner
.
op
,
T
.
Shape
)):
isinstance
(
s_i
.
owner
.
inputs
[
0
]
.
owner
.
op
,
T
.
Shape
)):
assert
s_i
.
ndim
==
0
assert
s_i
.
ndim
==
0
assert
len
(
s_i
.
owner
.
op
.
idx_list
)
==
1
assert
len
(
s_i
.
owner
.
op
.
idx_list
)
==
1
...
@@ -883,7 +877,7 @@ class ShapeFeature(object):
...
@@ -883,7 +877,7 @@ class ShapeFeature(object):
return
s_i
return
s_i
else
:
else
:
raise
TypeError
(
'Unsupported shape element'
,
raise
TypeError
(
'Unsupported shape element'
,
s_i
,
type
(
s_i
),
getattr
(
s_i
,
'type'
,
None
))
s_i
,
type
(
s_i
),
getattr
(
s_i
,
'type'
,
None
))
def
set_shape
(
self
,
r
,
s
):
def
set_shape
(
self
,
r
,
s
):
"""Assign the shape `s` to previously un-shaped variable `r`.
"""Assign the shape `s` to previously un-shaped variable `r`.
...
@@ -910,7 +904,7 @@ class ShapeFeature(object):
...
@@ -910,7 +904,7 @@ class ShapeFeature(object):
shape_vars
=
[]
shape_vars
=
[]
for
i
in
xrange
(
r
.
ndim
):
for
i
in
xrange
(
r
.
ndim
):
if
(
hasattr
(
r
.
type
,
'broadcastable'
)
and
if
(
hasattr
(
r
.
type
,
'broadcastable'
)
and
r
.
type
.
broadcastable
[
i
]):
r
.
type
.
broadcastable
[
i
]):
shape_vars
.
append
(
self
.
lscalar_one
)
shape_vars
.
append
(
self
.
lscalar_one
)
else
:
else
:
shape_vars
.
append
(
self
.
unpack
(
s
[
i
]))
shape_vars
.
append
(
self
.
unpack
(
s
[
i
]))
...
@@ -947,8 +941,8 @@ class ShapeFeature(object):
...
@@ -947,8 +941,8 @@ class ShapeFeature(object):
self
.
set_shape
(
r
,
other_shape
)
self
.
set_shape
(
r
,
other_shape
)
return
return
if
(
other_r
.
owner
and
r
.
owner
and
if
(
other_r
.
owner
and
r
.
owner
and
other_r
.
owner
.
inputs
==
r
.
owner
.
inputs
and
other_r
.
owner
.
inputs
==
r
.
owner
.
inputs
and
other_r
.
owner
.
op
==
r
.
owner
.
op
):
other_r
.
owner
.
op
==
r
.
owner
.
op
):
# We are doing a merge. So the 2 shapes graph will be the
# We are doing a merge. So the 2 shapes graph will be the
# same. This is only a speed optimization to call
# same. This is only a speed optimization to call
# ancestors() less frequently.
# ancestors() less frequently.
...
@@ -957,10 +951,10 @@ class ShapeFeature(object):
...
@@ -957,10 +951,10 @@ class ShapeFeature(object):
# Merge other_shape with r_shape, giving the priority to other_shape
# Merge other_shape with r_shape, giving the priority to other_shape
merged_shape
=
[]
merged_shape
=
[]
for
i
,
ps
in
enumerate
(
other_shape
):
for
i
,
ps
in
enumerate
(
other_shape
):
if
(
ps
.
owner
if
(
ps
.
owner
and
and
isinstance
(
getattr
(
ps
.
owner
,
'op'
,
None
),
Shape_i
)
isinstance
(
getattr
(
ps
.
owner
,
'op'
,
None
),
Shape_i
)
and
and
ps
.
owner
.
op
.
i
==
i
ps
.
owner
.
op
.
i
==
i
and
and
ps
.
owner
.
inputs
[
0
]
in
(
r
,
other_r
)):
ps
.
owner
.
inputs
[
0
]
in
(
r
,
other_r
)):
# If other_shape[i] is uninformative, use r_shape[i].
# If other_shape[i] is uninformative, use r_shape[i].
# For now, we consider 2 cases of uninformative other_shape[i]:
# For now, we consider 2 cases of uninformative other_shape[i]:
# - Shape_i(i)(other_r);
# - Shape_i(i)(other_r);
...
@@ -1084,11 +1078,11 @@ class ShapeFeature(object):
...
@@ -1084,11 +1078,11 @@ class ShapeFeature(object):
r
in
node
.
inputs
])
r
in
node
.
inputs
])
except
NotImplementedError
as
e
:
except
NotImplementedError
as
e
:
raise
NotImplementedError
(
raise
NotImplementedError
(
'Code called by infer_shape failed raising a '
'Code called by infer_shape failed raising a '
'NotImplementedError. Raising NotImplementedError to '
'NotImplementedError. Raising NotImplementedError to '
'indicate that a shape cannot be computed is no longer '
'indicate that a shape cannot be computed is no longer '
'supported, and one should now use tensor.ShapeError '
'supported, and one should now use tensor.ShapeError '
'instead. The original exception message is:
%
s'
%
e
)
'instead. The original exception message is:
%
s'
%
e
)
except
Exception
as
e
:
except
Exception
as
e
:
msg
=
(
'Failed to infer_shape from Op
%
s.
\n
Input shapes: '
msg
=
(
'Failed to infer_shape from Op
%
s.
\n
Input shapes: '
'
%
s
\n
Exception encountered during infer_shape: '
'
%
s
\n
Exception encountered during infer_shape: '
...
@@ -1108,10 +1102,10 @@ class ShapeFeature(object):
...
@@ -1108,10 +1102,10 @@ class ShapeFeature(object):
if
len
(
o_shapes
)
!=
len
(
node
.
outputs
):
if
len
(
o_shapes
)
!=
len
(
node
.
outputs
):
raise
Exception
(
raise
Exception
(
(
'The infer_shape method for the Op "
%
s" returned a list '
+
(
'The infer_shape method for the Op "
%
s" returned a list '
+
'with the wrong number of element: len(o_shapes) =
%
d '
+
'with the wrong number of element: len(o_shapes) =
%
d '
+
' != len(node.outputs) =
%
d'
)
%
(
str
(
node
.
op
),
' != len(node.outputs) =
%
d'
)
%
(
str
(
node
.
op
),
len
(
o_shapes
),
len
(
o_shapes
),
len
(
node
.
outputs
)))
len
(
node
.
outputs
)))
# Ensure shapes are in 'int64'. This is to make sure the assert
# Ensure shapes are in 'int64'. This is to make sure the assert
# found in the `local_useless_subtensor` optimization does not fail.
# found in the `local_useless_subtensor` optimization does not fail.
...
@@ -1173,9 +1167,9 @@ class ShapeFeature(object):
...
@@ -1173,9 +1167,9 @@ class ShapeFeature(object):
# with the InputToGpuOptimizer optimizer.
# with the InputToGpuOptimizer optimizer.
continue
continue
if
(
repl
.
owner
and
if
(
repl
.
owner
and
repl
.
owner
.
inputs
[
0
]
is
shpnode
.
inputs
[
0
]
and
repl
.
owner
.
inputs
[
0
]
is
shpnode
.
inputs
[
0
]
and
isinstance
(
repl
.
owner
.
op
,
Shape_i
)
and
isinstance
(
repl
.
owner
.
op
,
Shape_i
)
and
repl
.
owner
.
op
.
i
==
shpnode
.
op
.
i
):
repl
.
owner
.
op
.
i
==
shpnode
.
op
.
i
):
# The replacement is a shape_i of the same
# The replacement is a shape_i of the same
# input. So no need to do this equivalent
# input. So no need to do this equivalent
# replacement.
# replacement.
...
@@ -1239,7 +1233,7 @@ class ShapeFeature(object):
...
@@ -1239,7 +1233,7 @@ class ShapeFeature(object):
if
not
dx
.
owner
or
not
dy
.
owner
:
if
not
dx
.
owner
or
not
dy
.
owner
:
return
False
return
False
if
(
not
isinstance
(
dx
.
owner
.
op
,
Shape_i
)
or
if
(
not
isinstance
(
dx
.
owner
.
op
,
Shape_i
)
or
not
isinstance
(
dy
.
owner
.
op
,
Shape_i
)):
not
isinstance
(
dy
.
owner
.
op
,
Shape_i
)):
return
False
return
False
opx
=
dx
.
owner
.
op
opx
=
dx
.
owner
.
op
opy
=
dy
.
owner
.
op
opy
=
dy
.
owner
.
op
...
@@ -1310,10 +1304,9 @@ def local_fill_to_alloc(node):
...
@@ -1310,10 +1304,9 @@ def local_fill_to_alloc(node):
return
return
# TODO: cut out un-necessary dimshuffles of v
# TODO: cut out un-necessary dimshuffles of v
assert
rval
[
0
]
.
type
==
node
.
outputs
[
0
]
.
type
,
(
'rval'
,
rval
[
0
]
.
type
,
assert
rval
[
0
]
.
type
==
node
.
outputs
[
0
]
.
type
,
(
'orig'
,
node
.
outputs
[
0
]
.
type
,
'rval'
,
rval
[
0
]
.
type
,
'orig'
,
node
.
outputs
[
0
]
.
type
,
'node'
,
'node'
,
node
,
node
,)
# theano.printing.debugprint(node.outputs[0], file='str'))
)
# theano.printing.debugprint(node.outputs[0], file='str'))
return
rval
return
rval
...
@@ -1404,7 +1397,7 @@ def local_subtensor_make_vector(node):
...
@@ -1404,7 +1397,7 @@ def local_subtensor_make_vector(node):
try
:
try
:
idx
,
=
node
.
op
.
idx_list
idx
,
=
node
.
op
.
idx_list
except
Exception
:
except
Exception
:
#'how can you have multiple indexes into a shape?'
#
'how can you have multiple indexes into a shape?'
raise
raise
if
isinstance
(
idx
,
(
scalar
.
Scalar
,
T
.
TensorType
)):
if
isinstance
(
idx
,
(
scalar
.
Scalar
,
T
.
TensorType
)):
...
@@ -1467,13 +1460,13 @@ def local_useless_elemwise(node):
...
@@ -1467,13 +1460,13 @@ def local_useless_elemwise(node):
if
isinstance
(
node
.
op
,
T
.
Elemwise
):
if
isinstance
(
node
.
op
,
T
.
Elemwise
):
if
node
.
op
.
scalar_op
==
theano
.
scalar
.
eq
and
len
(
node
.
inputs
)
==
2
:
if
node
.
op
.
scalar_op
==
theano
.
scalar
.
eq
and
len
(
node
.
inputs
)
==
2
:
if
node
.
inputs
[
0
]
==
node
.
inputs
[
1
]:
if
node
.
inputs
[
0
]
==
node
.
inputs
[
1
]:
# it is the same var in the graph. That will always be true
# it is the same var in the graph. That will always be true
return
[
T
.
fill
(
node
.
inputs
[
0
],
return
[
T
.
fill
(
node
.
inputs
[
0
],
T
.
constant
(
1.0
,
T
.
constant
(
1.0
,
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
))]
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
))]
elif
node
.
op
.
scalar_op
==
theano
.
scalar
.
neq
and
len
(
node
.
inputs
)
==
2
:
elif
node
.
op
.
scalar_op
==
theano
.
scalar
.
neq
and
len
(
node
.
inputs
)
==
2
:
if
node
.
inputs
[
0
]
==
node
.
inputs
[
1
]:
if
node
.
inputs
[
0
]
==
node
.
inputs
[
1
]:
# it is the same var in the graph. That will always be false
# it is the same var in the graph. That will always be false
return
[
T
.
fill
(
node
.
inputs
[
0
],
return
[
T
.
fill
(
node
.
inputs
[
0
],
T
.
constant
(
0.0
,
T
.
constant
(
0.0
,
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
))]
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
))]
...
@@ -1482,8 +1475,8 @@ def local_useless_elemwise(node):
...
@@ -1482,8 +1475,8 @@ def local_useless_elemwise(node):
elif
node
.
op
.
scalar_op
==
theano
.
scalar
.
add
and
len
(
node
.
inputs
)
==
1
:
elif
node
.
op
.
scalar_op
==
theano
.
scalar
.
add
and
len
(
node
.
inputs
)
==
1
:
return
[
node
.
inputs
[
0
]]
return
[
node
.
inputs
[
0
]]
elif
(
node
.
op
.
scalar_op
==
theano
.
scalar
.
identity
elif
(
node
.
op
.
scalar_op
==
theano
.
scalar
.
identity
and
and
len
(
node
.
inputs
)
==
1
):
len
(
node
.
inputs
)
==
1
):
return
[
node
.
inputs
[
0
]]
return
[
node
.
inputs
[
0
]]
...
@@ -1513,12 +1506,12 @@ def local_cast_cast(node):
...
@@ -1513,12 +1506,12 @@ def local_cast_cast(node):
and the first cast cause an upcast.
and the first cast cause an upcast.
"""
"""
if
(
not
isinstance
(
node
.
op
,
T
.
Elemwise
)
or
if
(
not
isinstance
(
node
.
op
,
T
.
Elemwise
)
or
not
isinstance
(
node
.
op
.
scalar_op
,
scalar
.
Cast
)):
not
isinstance
(
node
.
op
.
scalar_op
,
scalar
.
Cast
)):
return
return
x
=
node
.
inputs
[
0
]
x
=
node
.
inputs
[
0
]
if
(
not
x
.
owner
or
if
(
not
x
.
owner
or
not
isinstance
(
x
.
owner
.
op
,
T
.
Elemwise
)
or
not
isinstance
(
x
.
owner
.
op
,
T
.
Elemwise
)
or
not
isinstance
(
x
.
owner
.
op
.
scalar_op
,
scalar
.
Cast
)):
not
isinstance
(
x
.
owner
.
op
.
scalar_op
,
scalar
.
Cast
)):
return
return
if
node
.
op
.
scalar_op
.
o_type
==
x
.
owner
.
op
.
scalar_op
.
o_type
:
if
node
.
op
.
scalar_op
.
o_type
==
x
.
owner
.
op
.
scalar_op
.
o_type
:
return
[
x
]
return
[
x
]
...
@@ -1738,7 +1731,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
...
@@ -1738,7 +1731,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# The broadcast pattern of the ouptut must match the broadcast
# The broadcast pattern of the ouptut must match the broadcast
# pattern of at least one of the inputs.
# pattern of at least one of the inputs.
if
not
any
([
i
.
type
.
broadcastable
==
if
not
any
([
i
.
type
.
broadcastable
==
node
.
outputs
[
0
]
.
type
.
broadcastable
for
i
in
node
.
inputs
]):
node
.
outputs
[
0
]
.
type
.
broadcastable
for
i
in
node
.
inputs
]):
return
False
return
False
def
dimshuffled_alloc
(
i
):
def
dimshuffled_alloc
(
i
):
...
@@ -1749,10 +1742,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
...
@@ -1749,10 +1742,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# At least one input must have an owner that is either a AllocOP or a
# At least one input must have an owner that is either a AllocOP or a
# DimShuffleOP with an owner that is a AllocOP -- otherwise there is
# DimShuffleOP with an owner that is a AllocOP -- otherwise there is
# nothing to optimize.
# nothing to optimize.
if
not
any
([
i
.
owner
if
not
any
([
i
.
owner
and
(
isinstance
(
i
.
owner
.
op
,
AllocOP
)
or
and
(
isinstance
(
i
.
owner
.
op
,
AllocOP
)
or
dimshuffled_alloc
(
i
))
for
i
in
node
.
inputs
]):
dimshuffled_alloc
(
i
))
for
i
in
node
.
inputs
]):
return
False
return
False
# Search for input that we can use as a baseline for the dimensions.
# Search for input that we can use as a baseline for the dimensions.
...
@@ -1761,9 +1752,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
...
@@ -1761,9 +1752,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
if
i
.
type
.
broadcastable
==
node
.
outputs
[
0
]
.
type
.
broadcastable
:
if
i
.
type
.
broadcastable
==
node
.
outputs
[
0
]
.
type
.
broadcastable
:
# Prefer an input that is not a AllocOP nor a DimShuffleOP of a
# Prefer an input that is not a AllocOP nor a DimShuffleOP of a
# AllocOP so that all allocs can be optimized.
# AllocOP so that all allocs can be optimized.
if
not
(
i
.
owner
if
not
(
i
.
owner
and
(
isinstance
(
i
.
owner
.
op
,
AllocOP
)
or
and
(
isinstance
(
i
.
owner
.
op
,
AllocOP
)
dimshuffled_alloc
(
i
))):
or
dimshuffled_alloc
(
i
))):
assert_op_idx
=
idx
assert_op_idx
=
idx
break
break
...
@@ -1773,8 +1763,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
...
@@ -1773,8 +1763,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# there is more than one then do all but one. number of
# there is more than one then do all but one. number of
# inputs with alloc or dimshuffle alloc
# inputs with alloc or dimshuffle alloc
l2
=
[
i
for
i
in
node
.
inputs
l2
=
[
i
for
i
in
node
.
inputs
if
(
i
.
owner
and
(
isinstance
(
i
.
owner
.
op
,
AllocOP
)
if
(
i
.
owner
and
(
isinstance
(
i
.
owner
.
op
,
AllocOP
)
or
or
dimshuffled_alloc
(
i
)))]
dimshuffled_alloc
(
i
)))]
# If only 1 alloc or dimshuffle alloc, it is the one we
# If only 1 alloc or dimshuffle alloc, it is the one we
# will use for the shape. So no alloc would be removed.
# will use for the shape. So no alloc would be removed.
if
len
(
l2
)
>
1
:
if
len
(
l2
)
>
1
:
...
@@ -1794,14 +1784,13 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
...
@@ -1794,14 +1784,13 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
same_shape
=
node
.
fgraph
.
shape_feature
.
same_shape
same_shape
=
node
.
fgraph
.
shape_feature
.
same_shape
for
i
in
node
.
inputs
:
for
i
in
node
.
inputs
:
# Remove alloc
# Remove alloc
if
(
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
AllocOP
)
if
(
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
AllocOP
)
and
and
i
.
owner
.
inputs
[
0
]
.
type
!=
i
.
owner
.
outputs
[
0
]
.
type
):
i
.
owner
.
inputs
[
0
]
.
type
!=
i
.
owner
.
outputs
[
0
]
.
type
):
# when i.owner.inputs[0].type == i.owner.outputs[0].type we
# when i.owner.inputs[0].type == i.owner.outputs[0].type we
# will remove that alloc later
# will remove that alloc later
assert
i
.
type
.
ndim
==
cmp_op
.
ndim
assert
i
.
type
.
ndim
==
cmp_op
.
ndim
if
(
theano
.
config
.
experimental
.
local_alloc_elemwise_assert
if
(
theano
.
config
.
experimental
.
local_alloc_elemwise_assert
and
and
not
same_shape
(
i
,
cmp_op
)):
not
same_shape
(
i
,
cmp_op
)):
assert_op
=
assert_
(
assert_op
,
assert_op
=
assert_
(
assert_op
,
*
[
T
.
eq
(
i
.
shape
[
idx
],
cmp_op
.
shape
[
idx
])
*
[
T
.
eq
(
i
.
shape
[
idx
],
cmp_op
.
shape
[
idx
])
for
idx
in
xrange
(
i
.
type
.
ndim
)
for
idx
in
xrange
(
i
.
type
.
ndim
)
...
@@ -1891,7 +1880,7 @@ def local_upcast_elemwise_constant_inputs(node):
...
@@ -1891,7 +1880,7 @@ def local_upcast_elemwise_constant_inputs(node):
scalar_op
=
node
.
op
.
scalar_op
scalar_op
=
node
.
op
.
scalar_op
# print "aa", scalar_op.output_types_preference
# print "aa", scalar_op.output_types_preference
if
(
getattr
(
scalar_op
,
'output_types_preference'
,
None
)
if
(
getattr
(
scalar_op
,
'output_types_preference'
,
None
)
in
(
T
.
scal
.
upgrade_to_float
,
T
.
scal
.
upcast_out
)):
in
(
T
.
scal
.
upgrade_to_float
,
T
.
scal
.
upcast_out
)):
# this is the kind of op that we can screw with the input
# this is the kind of op that we can screw with the input
# dtypes by upcasting explicitly
# dtypes by upcasting explicitly
output_dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
output_dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
...
@@ -1909,12 +1898,12 @@ def local_upcast_elemwise_constant_inputs(node):
...
@@ -1909,12 +1898,12 @@ def local_upcast_elemwise_constant_inputs(node):
i
.
ndim
))
i
.
ndim
))
else
:
else
:
if
shape_i
is
None
:
if
shape_i
is
None
:
return
return
new_inputs
.
append
(
new_inputs
.
append
(
T
.
alloc
(
T
.
cast
(
cval_i
,
T
.
alloc
(
T
.
cast
(
cval_i
,
output_dtype
)
,
output_dtype
),
*
[
shape_i
(
d
)(
i
)
*
[
shape_i
(
d
)(
i
)
for
d
in
xrange
(
i
.
ndim
)]))
for
d
in
xrange
(
i
.
ndim
)]))
#print >> sys.stderr, "AAA",
#
print >> sys.stderr, "AAA",
#*[Shape_i(d)(i) for d in xrange(i.ndim)]
#
*[Shape_i(d)(i) for d in xrange(i.ndim)]
except
NotScalarConstantError
:
except
NotScalarConstantError
:
# for the case of a non-scalar
# for the case of a non-scalar
if
isinstance
(
i
,
T
.
TensorConstant
):
if
isinstance
(
i
,
T
.
TensorConstant
):
...
@@ -1958,7 +1947,7 @@ def local_useless_inc_subtensor(node):
...
@@ -1958,7 +1947,7 @@ def local_useless_inc_subtensor(node):
except
NotScalarConstantError
:
except
NotScalarConstantError
:
return
return
if
(
node
.
inputs
[
0
]
.
ndim
!=
node
.
inputs
[
1
]
.
ndim
or
if
(
node
.
inputs
[
0
]
.
ndim
!=
node
.
inputs
[
1
]
.
ndim
or
node
.
inputs
[
0
]
.
broadcastable
!=
node
.
inputs
[
1
]
.
broadcastable
):
node
.
inputs
[
0
]
.
broadcastable
!=
node
.
inputs
[
1
]
.
broadcastable
):
# FB: I didn't check if this case can happen, but this opt
# FB: I didn't check if this case can happen, but this opt
# don't support it.
# don't support it.
return
return
...
@@ -1994,16 +1983,16 @@ def local_set_to_inc_subtensor(node):
...
@@ -1994,16 +1983,16 @@ def local_set_to_inc_subtensor(node):
AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False)
AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False)
"""
"""
if
(
isinstance
(
node
.
op
,
AdvancedIncSubtensor1
)
and
if
(
isinstance
(
node
.
op
,
AdvancedIncSubtensor1
)
and
node
.
op
.
set_instead_of_inc
==
True
and
node
.
op
.
set_instead_of_inc
and
node
.
inputs
[
1
]
.
owner
and
node
.
inputs
[
1
]
.
owner
and
isinstance
(
node
.
inputs
[
1
]
.
owner
.
op
,
Elemwise
)
and
isinstance
(
node
.
inputs
[
1
]
.
owner
.
op
,
Elemwise
)
and
isinstance
(
node
.
inputs
[
1
]
.
owner
.
op
.
scalar_op
,
scalar
.
Add
)):
isinstance
(
node
.
inputs
[
1
]
.
owner
.
op
.
scalar_op
,
scalar
.
Add
)):
addn
=
node
.
inputs
[
1
]
.
owner
addn
=
node
.
inputs
[
1
]
.
owner
subn
=
None
subn
=
None
other
=
None
other
=
None
if
(
addn
.
inputs
[
0
]
.
owner
and
if
(
addn
.
inputs
[
0
]
.
owner
and
isinstance
(
addn
.
inputs
[
0
]
.
owner
.
op
,
AdvancedSubtensor1
)):
isinstance
(
addn
.
inputs
[
0
]
.
owner
.
op
,
AdvancedSubtensor1
)):
subn
=
addn
.
inputs
[
0
]
.
owner
subn
=
addn
.
inputs
[
0
]
.
owner
other
=
addn
.
inputs
[
1
]
other
=
addn
.
inputs
[
1
]
elif
(
addn
.
inputs
[
1
]
.
owner
and
elif
(
addn
.
inputs
[
1
]
.
owner
and
...
@@ -2013,7 +2002,7 @@ def local_set_to_inc_subtensor(node):
...
@@ -2013,7 +2002,7 @@ def local_set_to_inc_subtensor(node):
else
:
else
:
return
return
if
(
subn
.
inputs
[
1
]
!=
node
.
inputs
[
2
]
or
if
(
subn
.
inputs
[
1
]
!=
node
.
inputs
[
2
]
or
subn
.
inputs
[
0
]
!=
node
.
inputs
[
0
]):
subn
.
inputs
[
0
]
!=
node
.
inputs
[
0
]):
return
return
return
[
advanced_inc_subtensor1
(
node
.
inputs
[
0
],
other
,
node
.
inputs
[
2
])]
return
[
advanced_inc_subtensor1
(
node
.
inputs
[
0
],
other
,
node
.
inputs
[
2
])]
...
@@ -2030,9 +2019,9 @@ def local_useless_slice(node):
...
@@ -2030,9 +2019,9 @@ def local_useless_slice(node):
last_slice
=
len
(
slices
)
last_slice
=
len
(
slices
)
for
s
in
slices
[::
-
1
]:
for
s
in
slices
[::
-
1
]:
# check if slice and then check slice indices
# check if slice and then check slice indices
if
(
isinstance
(
s
,
slice
)
and
s
.
start
is
None
and
s
.
stop
is
None
if
(
isinstance
(
s
,
slice
)
and
s
.
start
is
None
and
s
.
stop
is
None
and
and
(
s
.
step
is
None
or
T
.
extract_constant
(
s
.
step
)
==
1
)):
(
s
.
step
is
None
or
T
.
extract_constant
(
s
.
step
)
==
1
)):
last_slice
-=
1
last_slice
-=
1
else
:
else
:
break
break
# check if we removed something
# check if we removed something
...
@@ -2098,11 +2087,10 @@ def local_useless_subtensor(node):
...
@@ -2098,11 +2087,10 @@ def local_useless_subtensor(node):
# the same underlying variable.
# the same underlying variable.
if
(
length_pos_shape_i
.
owner
and
if
(
length_pos_shape_i
.
owner
and
isinstance
(
length_pos_shape_i
.
owner
.
op
,
isinstance
(
length_pos_shape_i
.
owner
.
op
,
T
.
ScalarFromTensor
)):
T
.
ScalarFromTensor
)):
length_pos_shape_i
=
length_pos_shape_i
.
owner
.
inputs
[
0
]
length_pos_shape_i
=
length_pos_shape_i
.
owner
.
inputs
[
0
]
elif
(
length_pos
.
owner
and
elif
(
length_pos
.
owner
and
isinstance
(
length_pos
.
owner
.
op
,
isinstance
(
length_pos
.
owner
.
op
,
T
.
TensorFromScalar
)):
T
.
TensorFromScalar
)):
length_pos
=
length_pos
.
owner
.
inputs
[
0
]
length_pos
=
length_pos
.
owner
.
inputs
[
0
]
else
:
else
:
# We did not find underlying variables of the same type
# We did not find underlying variables of the same type
...
@@ -2322,8 +2310,8 @@ def merge_two_slices(slice1, len1, slice2, len2):
...
@@ -2322,8 +2310,8 @@ def merge_two_slices(slice1, len1, slice2, len2):
pn_stop
=
sl1
.
start
+
(
sl2
.
start
-
1
)
*
sl1
.
step
pn_stop
=
sl1
.
start
+
(
sl2
.
start
-
1
)
*
sl1
.
step
pn_stop
=
T
.
switch
(
T
.
and_
(
T
.
lt
(
pn_stop
,
0
),
pn_stop
=
T
.
switch
(
T
.
and_
(
T
.
lt
(
pn_stop
,
0
),
T
.
gt
(
flen
,
0
)),
T
.
gt
(
flen
,
0
)),
-
len1
-
1
,
-
len1
-
1
,
T
.
minimum
(
pn_stop
,
sl1
.
stop
))
T
.
minimum
(
pn_stop
,
sl1
.
stop
))
pn_start
=
sl1
.
start
+
(
sl2
.
stop
-
1
)
*
sl1
.
step
pn_start
=
sl1
.
start
+
(
sl2
.
stop
-
1
)
*
sl1
.
step
pn_start
=
T
.
minimum
(
pn_start
,
sl1
.
stop
)
pn_start
=
T
.
minimum
(
pn_start
,
sl1
.
stop
)
pn_start
=
T
.
maximum
(
pn_start
,
0
)
pn_start
=
T
.
maximum
(
pn_start
,
0
)
...
@@ -2345,9 +2333,8 @@ def merge_two_slices(slice1, len1, slice2, len2):
...
@@ -2345,9 +2333,8 @@ def merge_two_slices(slice1, len1, slice2, len2):
pp_start
))
pp_start
))
stop
=
T
.
switch
(
T
.
lt
(
reverse2
*
reverse1
,
0
),
stop
=
T
.
switch
(
T
.
lt
(
reverse2
*
reverse1
,
0
),
T
.
switch
(
T
.
lt
(
reverse1
,
0
),
np_stop
,
pn_stop
),
T
.
switch
(
T
.
lt
(
reverse1
,
0
),
np_stop
,
pn_stop
),
T
.
switch
(
T
.
lt
(
reverse1
,
0
),
nn_stop
,
pp_stop
T
.
switch
(
T
.
lt
(
reverse1
,
0
),
nn_stop
,
pp_stop
))
))
step
=
T
.
switch
(
T
.
lt
(
reverse2
*
reverse1
,
0
),
n_step
,
p_step
)
step
=
T
.
switch
(
T
.
lt
(
reverse2
*
reverse1
,
0
),
n_step
,
p_step
)
start
=
T
.
switch
(
T
.
le
(
flen
,
0
),
0
,
start
)
start
=
T
.
switch
(
T
.
le
(
flen
,
0
),
0
,
start
)
...
@@ -2463,7 +2450,7 @@ def local_subtensor_of_alloc(node):
...
@@ -2463,7 +2450,7 @@ def local_subtensor_of_alloc(node):
# We check that the corresponding val dimensions was
# We check that the corresponding val dimensions was
# not a broadcasted dimensions.
# not a broadcasted dimensions.
if
(
val
.
type
.
ndim
>
(
i
-
n_added_dims
)
and
if
(
val
.
type
.
ndim
>
(
i
-
n_added_dims
)
and
val
.
type
.
broadcastable
[
i
-
n_added_dims
]):
val
.
type
.
broadcastable
[
i
-
n_added_dims
]):
val_slices
.
append
(
slice
(
None
))
val_slices
.
append
(
slice
(
None
))
else
:
else
:
val_slices
.
append
(
sl
)
val_slices
.
append
(
sl
)
...
@@ -2496,8 +2483,8 @@ def local_subtensor_of_alloc(node):
...
@@ -2496,8 +2483,8 @@ def local_subtensor_of_alloc(node):
rval
[
0
]
=
theano
.
tensor
.
unbroadcast
(
rval
[
0
]
=
theano
.
tensor
.
unbroadcast
(
rval
[
0
],
rval
[
0
],
*
[
i
for
i
,
(
b1
,
b2
)
in
enumerate
(
zip
(
rval
[
0
]
.
broadcastable
,
*
[
i
for
i
,
(
b1
,
b2
)
in
enumerate
(
zip
(
rval
[
0
]
.
broadcastable
,
node
.
outputs
[
0
]
.
broadcastable
))
node
.
outputs
[
0
]
.
broadcastable
))
if
b1
and
not
b2
])
if
b1
and
not
b2
])
return
rval
return
rval
...
@@ -2518,7 +2505,7 @@ def local_subtensor_of_dot(node):
...
@@ -2518,7 +2505,7 @@ def local_subtensor_of_dot(node):
if
not
isinstance
(
node
.
op
,
Subtensor
):
if
not
isinstance
(
node
.
op
,
Subtensor
):
return
return
if
(
not
node
.
inputs
[
0
]
.
owner
or
if
(
not
node
.
inputs
[
0
]
.
owner
or
not
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
T
.
Dot
)):
not
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
T
.
Dot
)):
return
return
# If there is other node that use the outputs of the dot
# If there is other node that use the outputs of the dot
# We don't want to compute twice the sub part.
# We don't want to compute twice the sub part.
...
@@ -2540,7 +2527,8 @@ def local_subtensor_of_dot(node):
...
@@ -2540,7 +2527,8 @@ def local_subtensor_of_dot(node):
# We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
# We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
# (dot also handles b.ndim < 2 as a special case)
# (dot also handles b.ndim < 2 as a special case)
if
b
.
ndim
>
1
and
len
(
b_indices
)
>=
b
.
ndim
-
1
:
if
b
.
ndim
>
1
and
len
(
b_indices
)
>=
b
.
ndim
-
1
:
b_indices
=
b_indices
[:
b
.
ndim
-
2
]
+
(
slice
(
None
,
None
,
None
),)
+
b_indices
[
b
.
ndim
-
2
:]
b_indices
=
(
b_indices
[:
b
.
ndim
-
2
]
+
(
slice
(
None
,
None
,
None
),)
+
b_indices
[
b
.
ndim
-
2
:])
a_sub
=
a
.
__getitem__
(
tuple
(
a_indices
))
a_sub
=
a
.
__getitem__
(
tuple
(
a_indices
))
b_sub
=
b
.
__getitem__
(
tuple
(
b_indices
))
if
b_indices
else
b
b_sub
=
b
.
__getitem__
(
tuple
(
b_indices
))
if
b_indices
else
b
...
@@ -2583,14 +2571,13 @@ def local_IncSubtensor_serialize(node):
...
@@ -2583,14 +2571,13 @@ def local_IncSubtensor_serialize(node):
"""
"""
def
movable
(
i
):
def
movable
(
i
):
# Return True iff this is a incsubtensor that we can move
# Return True iff this is a incsubtensor that we can move
return
i
.
owner
\
return
(
i
.
owner
and
and
isinstance
(
i
.
owner
.
op
,
(
IncSubtensor
,
isinstance
(
i
.
owner
.
op
,
(
IncSubtensor
,
AdvancedIncSubtensor1
,
AdvancedIncSubtensor1
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor
,))
and
))
\
i
.
type
==
o_type
and
and
i
.
type
==
o_type
\
len
(
i
.
clients
)
==
1
and
and
len
(
i
.
clients
)
==
1
\
not
i
.
owner
.
op
.
set_instead_of_inc
)
and
not
i
.
owner
.
op
.
set_instead_of_inc
if
node
.
op
==
T
.
add
:
if
node
.
op
==
T
.
add
:
o_type
=
node
.
outputs
[
0
]
.
type
o_type
=
node
.
outputs
[
0
]
.
type
...
@@ -2598,8 +2585,8 @@ def local_IncSubtensor_serialize(node):
...
@@ -2598,8 +2585,8 @@ def local_IncSubtensor_serialize(node):
movable_inputs
=
[
i
for
i
in
node
.
inputs
if
movable
(
i
)]
movable_inputs
=
[
i
for
i
in
node
.
inputs
if
movable
(
i
)]
if
movable_inputs
:
if
movable_inputs
:
new_inputs
=
[
i
for
i
in
node
.
inputs
if
not
movable
(
i
)]
\
new_inputs
=
([
i
for
i
in
node
.
inputs
if
not
movable
(
i
)]
+
+
[
mi
.
owner
.
inputs
[
0
]
for
mi
in
movable_inputs
]
[
mi
.
owner
.
inputs
[
0
]
for
mi
in
movable_inputs
])
new_add
=
T
.
add
(
*
new_inputs
)
new_add
=
T
.
add
(
*
new_inputs
)
# stack up the new incsubtensors
# stack up the new incsubtensors
...
@@ -2638,9 +2625,10 @@ def local_inplace_setsubtensor(node):
...
@@ -2638,9 +2625,10 @@ def local_inplace_setsubtensor(node):
return
[
new_node
]
return
[
new_node
]
return
False
return
False
compile
.
optdb
.
register
(
'local_inplace_setsubtensor'
,
compile
.
optdb
.
register
(
'local_inplace_setsubtensor'
,
TopoOptimizer
(
local_inplace_setsubtensor
,
TopoOptimizer
(
failure_callback
=
TopoOptimizer
.
warn_inplace
),
60
,
local_inplace_setsubtensor
,
'fast_run'
,
'inplace'
)
# DEBUG
failure_callback
=
TopoOptimizer
.
warn_inplace
),
60
,
'fast_run'
,
'inplace'
)
# DEBUG
@gof.local_optimizer
([
AdvancedIncSubtensor1
],
inplace
=
True
)
@gof.local_optimizer
([
AdvancedIncSubtensor1
],
inplace
=
True
)
...
@@ -2653,8 +2641,8 @@ def local_inplace_incsubtensor1(node):
...
@@ -2653,8 +2641,8 @@ def local_inplace_incsubtensor1(node):
return
False
return
False
compile
.
optdb
.
register
(
'local_inplace_incsubtensor1'
,
compile
.
optdb
.
register
(
'local_inplace_incsubtensor1'
,
TopoOptimizer
(
TopoOptimizer
(
local_inplace_incsubtensor1
,
local_inplace_incsubtensor1
,
failure_callback
=
TopoOptimizer
.
warn_inplace
),
failure_callback
=
TopoOptimizer
.
warn_inplace
),
60
,
'fast_run'
,
'inplace'
)
# DEBUG
60
,
'fast_run'
,
'inplace'
)
# DEBUG
...
@@ -2671,7 +2659,7 @@ def local_incsubtensor_of_zeros(node):
...
@@ -2671,7 +2659,7 @@ def local_incsubtensor_of_zeros(node):
if
(
isinstance
(
node
.
op
,
(
IncSubtensor
,
if
(
isinstance
(
node
.
op
,
(
IncSubtensor
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
))
and
AdvancedIncSubtensor1
))
and
not
node
.
op
.
set_instead_of_inc
):
not
node
.
op
.
set_instead_of_inc
):
x
=
node
.
inputs
[
0
]
x
=
node
.
inputs
[
0
]
y
=
node
.
inputs
[
1
]
y
=
node
.
inputs
[
1
]
replace
=
False
replace
=
False
...
@@ -2713,8 +2701,8 @@ def local_setsubtensor_of_constants(node):
...
@@ -2713,8 +2701,8 @@ def local_setsubtensor_of_constants(node):
pass
pass
if
(
replace_x
is
not
None
and
if
(
replace_x
is
not
None
and
replace_y
is
not
None
and
replace_y
is
not
None
and
replace_x
==
replace_y
):
replace_x
==
replace_y
):
return
[
x
]
return
[
x
]
else
:
else
:
return
False
return
False
...
@@ -2738,7 +2726,7 @@ def local_adv_sub1_adv_inc_sub1(node):
...
@@ -2738,7 +2726,7 @@ def local_adv_sub1_adv_inc_sub1(node):
return
return
inp
=
node
.
inputs
[
0
]
inp
=
node
.
inputs
[
0
]
if
(
not
inp
.
owner
or
if
(
not
inp
.
owner
or
not
isinstance
(
inp
.
owner
.
op
,
AdvancedIncSubtensor1
)):
not
isinstance
(
inp
.
owner
.
op
,
AdvancedIncSubtensor1
)):
return
return
idx
=
node
.
inputs
[
1
]
idx
=
node
.
inputs
[
1
]
idx2
=
inp
.
owner
.
inputs
[
2
]
idx2
=
inp
.
owner
.
inputs
[
2
]
...
@@ -2747,13 +2735,13 @@ def local_adv_sub1_adv_inc_sub1(node):
...
@@ -2747,13 +2735,13 @@ def local_adv_sub1_adv_inc_sub1(node):
if
idx
is
not
idx2
:
if
idx
is
not
idx2
:
return
return
if
(
not
inp
.
owner
.
op
.
set_instead_of_inc
and
if
(
not
inp
.
owner
.
op
.
set_instead_of_inc
and
T
.
extract_constant
(
x
)
!=
0
):
T
.
extract_constant
(
x
)
!=
0
):
return
return
cond
=
[
T
.
all
(
T
.
and_
(
T
.
lt
(
idx
,
x
.
shape
[
0
]),
cond
=
[
T
.
all
(
T
.
and_
(
T
.
lt
(
idx
,
x
.
shape
[
0
]),
T
.
ge
(
idx
,
-
x
.
shape
[
0
])))]
T
.
ge
(
idx
,
-
x
.
shape
[
0
])))]
if
not
node
.
fgraph
.
shape_feature
.
same_shape
(
idx
,
y
,
0
,
0
):
if
not
node
.
fgraph
.
shape_feature
.
same_shape
(
idx
,
y
,
0
,
0
):
cond
.
append
(
T
.
eq
(
idx
.
shape
[
0
],
y
.
shape
[
0
]))
cond
.
append
(
T
.
eq
(
idx
.
shape
[
0
],
y
.
shape
[
0
]))
y
=
Assert
(
"Bad indexing or shapes in a AdvancedIncSubtensor1 that was optimized away"
)(
y
,
*
cond
)
y
=
Assert
(
"Bad indexing or shapes in a AdvancedIncSubtensor1 "
"that was optimized away"
)(
y
,
*
cond
)
if
y
.
dtype
==
node
.
outputs
[
0
]
.
dtype
:
if
y
.
dtype
==
node
.
outputs
[
0
]
.
dtype
:
return
[
y
]
return
[
y
]
...
@@ -2828,33 +2816,34 @@ def local_useless_inc_subtensor_alloc(node):
...
@@ -2828,33 +2816,34 @@ def local_useless_inc_subtensor_alloc(node):
# Build `z_broad` explicitly to include extra implicit dimensions.
# Build `z_broad` explicitly to include extra implicit dimensions.
z_broad
=
((
True
,)
*
(
xi
.
ndim
-
z
.
ndim
)
+
z
.
broadcastable
)
z_broad
=
((
True
,)
*
(
xi
.
ndim
-
z
.
ndim
)
+
z
.
broadcastable
)
cond
=
[
# The shapes of `y` and `xi` must either agree or `y` may
cond
=
[
# also have shape equal to 1 which may be treated as a
# The shapes of `y` and `xi` must either agree or `y` may
# broadcastable dimension by the subtensor op.
# also have shape equal to 1 which may be treated as a
T
.
or_
(
T
.
eq
(
y
.
shape
[
k
],
1
),
T
.
eq
(
y
.
shape
[
k
],
xi
.
shape
[
k
]))
# broadcastable dimension by the subtensor op.
# Loop over all dimensions.
T
.
or_
(
T
.
eq
(
y
.
shape
[
k
],
1
),
T
.
eq
(
y
.
shape
[
k
],
xi
.
shape
[
k
]))
for
k
in
xrange
(
xi
.
ndim
)
# Loop over all dimensions.
# We need to check the above shapes, if
for
k
in
xrange
(
xi
.
ndim
)
# * the pre-alloc increment `z` is broadcastable in
# We need to check the above shapes, if
# dimension `k` (if it isn't, then the shapes of `z` and
# * the pre-alloc increment `z` is broadcastable in
# `y` are the same by the definition of the `Alloc` op in
# dimension `k` (if it isn't, then the shapes of `z` and
# this dimension and replacing `y` by `z` will not hide a
# `y` are the same by the definition of the `Alloc` op in
# shape error), and
# this dimension and replacing `y` by `z` will not hide a
# * `xi` and `y` do not have the same shape in dimension
# shape error), and
# `k` or we cannot infer the shape statically (if the
# * `xi` and `y` do not have the same shape in dimension
# shapes of `xi` and `y` are not the same, then replacing
# `k` or we cannot infer the shape statically (if the
# `y` by `z` will hide the shape error of `y`), and
# shapes of `xi` and `y` are not the same, then replacing
# * the shape of `y` is not equal to 1 or we cannot infer
# `y` by `z` will hide the shape error of `y`), and
# the shape statically (if the shape of `y` is equal to
# * the shape of `y` is not equal to 1 or we cannot infer
# 1, then `y` is broadcasted by the inc_subtensor op
# the shape statically (if the shape of `y` is equal to
# internally, so the shapes of `xi` and `y` do not need
# 1, then `y` is broadcasted by the inc_subtensor op
# to match in dimension `k`; else we need to check at
# internally, so the shapes of `xi` and `y` do not need
# runtime that the shape of `y` is either 1 or the same
# to match in dimension `k`; else we need to check at
# as `xi` or otherwise replacing `y` by `z` will hide a
# runtime that the shape of `y` is either 1 or the same
# shape error).
# as `xi` or otherwise replacing `y` by `z` will hide a
if
(
z_broad
[
k
]
and
# shape error).
not
same_shape
(
xi
,
y
,
dim_x
=
k
,
dim_y
=
k
)
and
if
(
z_broad
[
k
]
and
shape_of
[
y
][
k
]
!=
1
)]
not
same_shape
(
xi
,
y
,
dim_x
=
k
,
dim_y
=
k
)
and
shape_of
[
y
][
k
]
!=
1
)]
if
len
(
cond
)
>
0
:
if
len
(
cond
)
>
0
:
msg
=
'`x[i]` and `y` do not have the same shape.'
msg
=
'`x[i]` and `y` do not have the same shape.'
...
@@ -2916,7 +2905,7 @@ def local_rebroadcast_lift(node):
...
@@ -2916,7 +2905,7 @@ def local_rebroadcast_lift(node):
# compilation phase.
# compilation phase.
if
hasattr
(
input
,
'clients'
)
and
len
(
input
.
clients
)
==
1
:
if
hasattr
(
input
,
'clients'
)
and
len
(
input
.
clients
)
==
1
:
rval
=
inode
.
op
.
make_node
(
T
.
Rebroadcast
(
*
list
(
op
.
axis
.
items
()))(
rval
=
inode
.
op
.
make_node
(
T
.
Rebroadcast
(
*
list
(
op
.
axis
.
items
()))(
inode
.
inputs
[
0
]))
.
outputs
inode
.
inputs
[
0
]))
.
outputs
return
rval
return
rval
if
inode
and
isinstance
(
inode
.
op
,
T
.
Rebroadcast
):
if
inode
and
isinstance
(
inode
.
op
,
T
.
Rebroadcast
):
# the "axis" specification in the outer Rebroadcast overrides
# the "axis" specification in the outer Rebroadcast overrides
...
@@ -3031,11 +3020,11 @@ def local_join_make_vector(node):
...
@@ -3031,11 +3020,11 @@ def local_join_make_vector(node):
for
idx
in
xrange
(
2
,
len
(
node
.
inputs
)):
for
idx
in
xrange
(
2
,
len
(
node
.
inputs
)):
inp
=
node
.
inputs
[
idx
]
inp
=
node
.
inputs
[
idx
]
if
(
inp
.
owner
and
if
(
inp
.
owner
and
isinstance
(
inp
.
owner
.
op
,
MakeVector
)
and
isinstance
(
inp
.
owner
.
op
,
MakeVector
)
and
new_inputs
[
-
1
]
.
owner
and
new_inputs
[
-
1
]
.
owner
and
isinstance
(
new_inputs
[
-
1
]
.
owner
.
op
,
MakeVector
)
and
isinstance
(
new_inputs
[
-
1
]
.
owner
.
op
,
MakeVector
)
and
# MakeVector have a dtype parameter
# MakeVector have a dtype parameter
inp
.
owner
.
op
==
new_inputs
[
-
1
]
.
owner
.
op
):
inp
.
owner
.
op
==
new_inputs
[
-
1
]
.
owner
.
op
):
inps
=
new_inputs
[
-
1
]
.
owner
.
inputs
+
inp
.
owner
.
inputs
inps
=
new_inputs
[
-
1
]
.
owner
.
inputs
+
inp
.
owner
.
inputs
new_inputs
[
-
1
]
=
inp
.
owner
.
op
(
*
inps
)
new_inputs
[
-
1
]
=
inp
.
owner
.
op
(
*
inps
)
else
:
else
:
...
@@ -3059,7 +3048,7 @@ def local_remove_switch_const_cond(node):
...
@@ -3059,7 +3048,7 @@ def local_remove_switch_const_cond(node):
if cond is constant and cond != 0: left
if cond is constant and cond != 0: left
"""
"""
if
(
isinstance
(
node
.
op
,
T
.
Elemwise
)
and
if
(
isinstance
(
node
.
op
,
T
.
Elemwise
)
and
isinstance
(
node
.
op
.
scalar_op
,
scalar
.
basic
.
Switch
)):
isinstance
(
node
.
op
.
scalar_op
,
scalar
.
basic
.
Switch
)):
cond
=
T
.
extract_constant
(
node
.
inputs
[
0
],
elemwise
=
False
)
cond
=
T
.
extract_constant
(
node
.
inputs
[
0
],
elemwise
=
False
)
if
type
(
cond
)
is
numpy
.
ndarray
and
cond
.
ndim
==
0
:
if
type
(
cond
)
is
numpy
.
ndarray
and
cond
.
ndim
==
0
:
if
cond
==
0
:
if
cond
==
0
:
...
@@ -3241,9 +3230,9 @@ def local_flatten_lift(node):
...
@@ -3241,9 +3230,9 @@ def local_flatten_lift(node):
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a flatten.
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a flatten.
"""
"""
if
(
isinstance
(
node
.
op
,
T
.
Flatten
)
and
if
(
isinstance
(
node
.
op
,
T
.
Flatten
)
and
node
.
inputs
[
0
]
.
owner
and
node
.
inputs
[
0
]
.
owner
and
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
T
.
Elemwise
)
and
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
T
.
Elemwise
)
and
len
(
node
.
inputs
[
0
]
.
owner
.
inputs
)
==
1
):
len
(
node
.
inputs
[
0
]
.
owner
.
inputs
)
==
1
):
f
=
node
.
op
(
node
.
inputs
[
0
]
.
owner
.
inputs
[
0
])
f
=
node
.
op
(
node
.
inputs
[
0
]
.
owner
.
inputs
[
0
])
e
=
node
.
inputs
[
0
]
.
owner
.
op
(
f
)
e
=
node
.
inputs
[
0
]
.
owner
.
op
(
f
)
return
[
e
]
return
[
e
]
...
@@ -3290,9 +3279,9 @@ def local_reshape_lift(node):
...
@@ -3290,9 +3279,9 @@ def local_reshape_lift(node):
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a reshape.
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a reshape.
"""
"""
if
(
isinstance
(
node
.
op
,
T
.
Reshape
)
and
if
(
isinstance
(
node
.
op
,
T
.
Reshape
)
and
node
.
inputs
[
0
]
.
owner
and
node
.
inputs
[
0
]
.
owner
and
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
T
.
Elemwise
)
and
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
T
.
Elemwise
)
and
len
(
node
.
inputs
[
0
]
.
owner
.
inputs
)
==
1
):
len
(
node
.
inputs
[
0
]
.
owner
.
inputs
)
==
1
):
r
=
node
.
op
(
node
.
inputs
[
0
]
.
owner
.
inputs
[
0
],
node
.
inputs
[
1
])
r
=
node
.
op
(
node
.
inputs
[
0
]
.
owner
.
inputs
[
0
],
node
.
inputs
[
1
])
e
=
node
.
inputs
[
0
]
.
owner
.
op
(
r
)
e
=
node
.
inputs
[
0
]
.
owner
.
op
(
r
)
# In rare case the original broadcast was (False, True), but
# In rare case the original broadcast was (False, True), but
...
@@ -3539,7 +3528,7 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -3539,7 +3528,7 @@ class Canonizer(gof.LocalOptimizer):
return
[
input
],
[]
return
[
input
],
[]
if
input
.
owner
is
None
or
input
.
owner
.
op
not
in
[
if
input
.
owner
is
None
or
input
.
owner
.
op
not
in
[
self
.
main
,
self
.
inverse
,
self
.
reciprocal
]:
self
.
main
,
self
.
inverse
,
self
.
reciprocal
]:
if
input
.
owner
and
isinstance
(
input
.
owner
.
op
,
T
.
DimShuffle
):
if
input
.
owner
and
isinstance
(
input
.
owner
.
op
,
T
.
DimShuffle
):
# If input is a DimShuffle of some input which does
# If input is a DimShuffle of some input which does
# something like this:
# something like this:
...
@@ -3552,9 +3541,9 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -3552,9 +3541,9 @@ class Canonizer(gof.LocalOptimizer):
# the num/denum of its input
# the num/denum of its input
dsn
=
input
.
owner
# dimshuffle node
dsn
=
input
.
owner
# dimshuffle node
dsop
=
dsn
.
op
# dimshuffle op
dsop
=
dsn
.
op
# dimshuffle op
dsi0
=
dsn
.
inputs
[
0
]
# the first input of the
# dimshuffle i.e. the ndarray to
# the first input of the dimshuffle i.e. the ndarray to redim
# redim
dsi0
=
dsn
.
inputs
[
0
]
# The compatible order is a DimShuffle "new_order" of the form:
# The compatible order is a DimShuffle "new_order" of the form:
# ('x', ..., 'x', 0, 1, 2, ..., dimshuffle_input.type.ndim)
# ('x', ..., 'x', 0, 1, 2, ..., dimshuffle_input.type.ndim)
...
@@ -3566,9 +3555,9 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -3566,9 +3555,9 @@ class Canonizer(gof.LocalOptimizer):
# different numbers of dimensions (hence why we can
# different numbers of dimensions (hence why we can
# discard its information - we know we can retrieve it
# discard its information - we know we can retrieve it
# later on).
# later on).
compatible_order
=
(
'x'
,)
*
(
input
.
type
.
ndim
compatible_order
=
(
(
'x'
,)
*
-
dsi0
.
type
.
ndim
)
+
tuple
(
(
input
.
type
.
ndim
-
dsi0
.
type
.
ndim
)
+
range
(
dsi0
.
type
.
ndim
))
tuple
(
range
(
dsi0
.
type
.
ndim
)
))
if
dsop
.
new_order
==
compatible_order
:
if
dsop
.
new_order
==
compatible_order
:
# If the "new_order" is the one we recognize,
# If the "new_order" is the one we recognize,
# we return the num_denum of the dimshuffled input.
# we return the num_denum of the dimshuffled input.
...
@@ -3815,9 +3804,9 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -3815,9 +3804,9 @@ class Canonizer(gof.LocalOptimizer):
new
=
self
.
merge_num_denum
(
num
,
denum
)
new
=
self
.
merge_num_denum
(
num
,
denum
)
if
new
.
type
.
dtype
!=
out
.
type
.
dtype
:
if
new
.
type
.
dtype
!=
out
.
type
.
dtype
:
#new = T.fill(out, new)
#
new = T.fill(out, new)
elem_op
=
T
.
Elemwise
(
scalar
.
Identity
(
scalar
.
specific_out
(
elem_op
=
T
.
Elemwise
(
scalar
.
Identity
(
scalar
.
specific_out
(
getattr
(
scalar
,
out
.
type
.
dtype
))))
getattr
(
scalar
,
out
.
type
.
dtype
))))
new
=
elem_op
(
new
)
new
=
elem_op
(
new
)
assert
(
new
.
type
==
out
.
type
)
==
(
not
(
new
.
type
!=
out
.
type
))
assert
(
new
.
type
==
out
.
type
)
==
(
not
(
new
.
type
!=
out
.
type
))
...
@@ -3833,12 +3822,12 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -3833,12 +3822,12 @@ class Canonizer(gof.LocalOptimizer):
else
:
else
:
_logger
.
warning
(
' '
.
join
((
'CANONIZE FAILED: new, out = '
,
_logger
.
warning
(
' '
.
join
((
'CANONIZE FAILED: new, out = '
,
new
,
','
,
out
,
'types'
,
new
,
','
,
out
,
'types'
,
new
.
type
,
','
,
out
.
type
)))
new
.
type
,
','
,
out
.
type
)))
return
False
return
False
def
__str__
(
self
):
def
__str__
(
self
):
return
getattr
(
self
,
'name'
,
'Canonizer(
%
s,
%
s,
%
s)'
%
(
return
getattr
(
self
,
'name'
,
'Canonizer(
%
s,
%
s,
%
s)'
%
(
self
.
main
,
self
.
inverse
,
self
.
reciprocal
))
self
.
main
,
self
.
inverse
,
self
.
reciprocal
))
def
mul_calculate
(
num
,
denum
,
aslist
=
False
,
out_type
=
None
):
def
mul_calculate
(
num
,
denum
,
aslist
=
False
,
out_type
=
None
):
...
@@ -3872,7 +3861,7 @@ register_canonicalize(local_mul_canonizer, name='local_mul_canonizer')
...
@@ -3872,7 +3861,7 @@ register_canonicalize(local_mul_canonizer, name='local_mul_canonizer')
def
local_neg_to_mul
(
node
):
def
local_neg_to_mul
(
node
):
if
node
.
op
==
T
.
neg
:
if
node
.
op
==
T
.
neg
:
return
[
T
.
mul
(
numpy
.
array
(
-
1
,
dtype
=
node
.
inputs
[
0
]
.
dtype
),
return
[
T
.
mul
(
numpy
.
array
(
-
1
,
dtype
=
node
.
inputs
[
0
]
.
dtype
),
node
.
inputs
[
0
])]
node
.
inputs
[
0
])]
register_canonicalize
(
local_neg_to_mul
)
register_canonicalize
(
local_neg_to_mul
)
...
@@ -3924,10 +3913,10 @@ def local_elemwise_sub_zeros(node):
...
@@ -3924,10 +3913,10 @@ def local_elemwise_sub_zeros(node):
"""
"""
Elemwise{sub}(X,X) -> zeros_like(X)
Elemwise{sub}(X,X) -> zeros_like(X)
"""
"""
if
(
isinstance
(
node
.
op
,
T
.
Elemwise
)
if
(
isinstance
(
node
.
op
,
T
.
Elemwise
)
and
and
node
.
op
.
scalar_op
.
nin
==
2
node
.
op
.
scalar_op
.
nin
==
2
and
and
node
.
op
.
scalar_op
==
scalar
.
sub
node
.
op
.
scalar_op
==
scalar
.
sub
and
and
node
.
inputs
[
0
]
==
node
.
inputs
[
1
]):
node
.
inputs
[
0
]
==
node
.
inputs
[
1
]):
return
[
T
.
zeros_like
(
node
.
inputs
[
0
])]
return
[
T
.
zeros_like
(
node
.
inputs
[
0
])]
...
@@ -4013,9 +4002,8 @@ def local_sum_div_dimshuffle(node):
...
@@ -4013,9 +4002,8 @@ def local_sum_div_dimshuffle(node):
' to False.'
)
' to False.'
)
new_denom
=
T
.
DimShuffle
(
new_denom
=
T
.
DimShuffle
(
thing_dimshuffled
.
type
.
broadcastable
,
thing_dimshuffled
.
type
.
broadcastable
,
new_new_order
new_new_order
)(
thing_dimshuffled
)
)(
thing_dimshuffled
)
return
[
T
.
true_div
(
node
.
op
(
numerator
),
new_denom
)]
return
[
T
.
true_div
(
node
.
op
(
numerator
),
new_denom
)]
# else:
# else:
# print 'incompatible dims:', axis, new_order
# print 'incompatible dims:', axis, new_order
...
@@ -4052,8 +4040,9 @@ def local_op_of_op(node):
...
@@ -4052,8 +4040,9 @@ def local_op_of_op(node):
# We manipulate the graph so this is done to make sure the opt
# We manipulate the graph so this is done to make sure the opt
# doesn't affect other computations.
# doesn't affect other computations.
if
len
(
node_inps
.
clients
)
==
1
:
if
len
(
node_inps
.
clients
)
==
1
:
if
(
node_inps
.
owner
and
(
isinstance
(
node_inps
.
owner
.
op
,
T
.
elemwise
.
Prod
)
if
(
node_inps
.
owner
and
or
isinstance
(
node_inps
.
owner
.
op
,
T
.
elemwise
.
Sum
))):
(
isinstance
(
node_inps
.
owner
.
op
,
T
.
elemwise
.
Prod
)
or
isinstance
(
node_inps
.
owner
.
op
,
T
.
elemwise
.
Sum
))):
# check to see either the inner or outer prod is doing a
# check to see either the inner or outer prod is doing a
# product over all axis, in which case we can remove it
# product over all axis, in which case we can remove it
...
@@ -4074,7 +4063,6 @@ def local_op_of_op(node):
...
@@ -4074,7 +4063,6 @@ def local_op_of_op(node):
assert
len
(
newaxis
)
==
len
(
list
(
node_inps
.
owner
.
op
.
axis
)
+
assert
len
(
newaxis
)
==
len
(
list
(
node_inps
.
owner
.
op
.
axis
)
+
list
(
node
.
op
.
axis
))
list
(
node
.
op
.
axis
))
# The old bugged logic. We keep it there to generate a warning
# The old bugged logic. We keep it there to generate a warning
# when we generated bad code.
# when we generated bad code.
alldims
=
list
(
range
(
node_inps
.
owner
.
inputs
[
0
]
.
type
.
ndim
))
alldims
=
list
(
range
(
node_inps
.
owner
.
inputs
[
0
]
.
type
.
ndim
))
...
@@ -4087,20 +4075,20 @@ def local_op_of_op(node):
...
@@ -4087,20 +4075,20 @@ def local_op_of_op(node):
if
i
not
in
alldims
]
if
i
not
in
alldims
]
if
(
theano
.
config
.
warn
.
sum_sum_bug
and
if
(
theano
.
config
.
warn
.
sum_sum_bug
and
newaxis
!=
newaxis_old
and
newaxis
!=
newaxis_old
and
len
(
newaxis
)
==
len
(
newaxis_old
)):
len
(
newaxis
)
==
len
(
newaxis_old
)):
_logger
.
warn
(
_logger
.
warn
(
"WARNING (YOUR CURRENT CODE IS FINE): Theano "
"WARNING (YOUR CURRENT CODE IS FINE): Theano "
"versions between version 9923a40c7b7a and August "
"versions between version 9923a40c7b7a and August "
"2nd, 2010 generated bugged code in this case. "
"2nd, 2010 generated bugged code in this case. "
"This happens when there are two consecutive sums "
"This happens when there are two consecutive sums "
"in the graph and the intermediate sum is not "
"in the graph and the intermediate sum is not "
"used elsewhere in the code. Some safeguard "
"used elsewhere in the code. Some safeguard "
"removed some bad code, but not in all cases. You "
"removed some bad code, but not in all cases. You "
"are in one such case. To disable this warning "
"are in one such case. To disable this warning "
"(that you can safely ignore since this bug has "
"(that you can safely ignore since this bug has "
"been fixed) set the theano flag "
"been fixed) set the theano flag "
"`warn.sum_sum_bug` to False."
)
"`warn.sum_sum_bug` to False."
)
combined
=
opt_type
(
newaxis
,
dtype
=
out_dtype
)
combined
=
opt_type
(
newaxis
,
dtype
=
out_dtype
)
return
[
combined
(
node_inps
.
owner
.
inputs
[
0
])]
return
[
combined
(
node_inps
.
owner
.
inputs
[
0
])]
...
@@ -4126,9 +4114,8 @@ def local_reduce_join(node):
...
@@ -4126,9 +4114,8 @@ def local_reduce_join(node):
"""
"""
if
(
isinstance
(
node
.
op
,
T
.
CAReduce
)
and
if
(
isinstance
(
node
.
op
,
T
.
CAReduce
)
and
node
.
inputs
[
0
]
.
owner
and
node
.
inputs
[
0
]
.
owner
and
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
T
.
Join
)):
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
T
.
Join
)):
join
=
node
.
inputs
[
0
]
.
owner
join
=
node
.
inputs
[
0
]
.
owner
if
T
.
extract_constant
(
join
.
inputs
[
0
])
!=
0
:
if
T
.
extract_constant
(
join
.
inputs
[
0
])
!=
0
:
return
return
...
@@ -4149,7 +4136,8 @@ def local_reduce_join(node):
...
@@ -4149,7 +4136,8 @@ def local_reduce_join(node):
if
not
inp
:
if
not
inp
:
return
return
if
(
not
isinstance
(
inp
.
op
,
DimShuffle
)
or
if
(
not
isinstance
(
inp
.
op
,
DimShuffle
)
or
inp
.
op
.
new_order
!=
(
'x'
,)
+
tuple
(
range
(
inp
.
inputs
[
0
]
.
ndim
))):
inp
.
op
.
new_order
!=
(
'x'
,)
+
tuple
(
range
(
inp
.
inputs
[
0
]
.
ndim
))):
return
return
new_inp
.
append
(
inp
.
inputs
[
0
])
new_inp
.
append
(
inp
.
inputs
[
0
])
ret
=
Elemwise
(
node
.
op
.
scalar_op
)(
*
new_inp
)
ret
=
Elemwise
(
node
.
op
.
scalar_op
)(
*
new_inp
)
...
@@ -4174,9 +4162,8 @@ def local_reduce_join(node):
...
@@ -4174,9 +4162,8 @@ def local_reduce_join(node):
'optimization, that modified the pattern '
'optimization, that modified the pattern '
'"Reduce{scalar.op}(Join(axis=0, a, b), axis=0)", '
'"Reduce{scalar.op}(Join(axis=0, a, b), axis=0)", '
'did not check the reduction axis. So if the '
'did not check the reduction axis. So if the '
'reduction axis was not 0, you got a wrong answer.'
'reduction axis was not 0, you got a wrong answer.'
))
))
return
return
# We add the new check late to don't add extra warning.
# We add the new check late to don't add extra warning.
try
:
try
:
...
@@ -4204,7 +4191,7 @@ def local_cut_useless_reduce(node):
...
@@ -4204,7 +4191,7 @@ def local_cut_useless_reduce(node):
# theano/tensor/tests/test_opt.py:T_local_reduce.test_local_reduce_broadcast_some_0
# theano/tensor/tests/test_opt.py:T_local_reduce.test_local_reduce_broadcast_some_0
# see gh-790 issue.
# see gh-790 issue.
#
#
#@register_canonicalize
#
@register_canonicalize
@register_uncanonicalize
@register_uncanonicalize
@register_specialize
@register_specialize
@gof.local_optimizer
(
ALL_REDUCE
)
@gof.local_optimizer
(
ALL_REDUCE
)
...
@@ -4258,7 +4245,7 @@ def local_opt_alloc(node):
...
@@ -4258,7 +4245,7 @@ def local_opt_alloc(node):
input
=
node_inps
.
owner
.
inputs
[
0
]
input
=
node_inps
.
owner
.
inputs
[
0
]
shapes
=
node_inps
.
owner
.
inputs
[
1
:]
shapes
=
node_inps
.
owner
.
inputs
[
1
:]
if
(
node
.
op
.
axis
is
None
or
if
(
node
.
op
.
axis
is
None
or
node
.
op
.
axis
==
tuple
(
range
(
input
.
ndim
))):
node
.
op
.
axis
==
tuple
(
range
(
input
.
ndim
))):
try
:
try
:
val
=
get_scalar_constant_value
(
input
)
val
=
get_scalar_constant_value
(
input
)
assert
val
.
size
==
1
assert
val
.
size
==
1
...
@@ -4346,7 +4333,7 @@ register_canonicalize(local_mul_zero)
...
@@ -4346,7 +4333,7 @@ register_canonicalize(local_mul_zero)
@gof.local_optimizer
([
T
.
true_div
])
@gof.local_optimizer
([
T
.
true_div
])
def
local_div_to_inv
(
node
):
def
local_div_to_inv
(
node
):
if
node
.
op
==
T
.
true_div
and
N
.
all
(
if
node
.
op
==
T
.
true_div
and
N
.
all
(
local_mul_canonizer
.
get_constant
(
node
.
inputs
[
0
])
==
1.0
):
local_mul_canonizer
.
get_constant
(
node
.
inputs
[
0
])
==
1.0
):
out
=
node
.
outputs
[
0
]
out
=
node
.
outputs
[
0
]
new_out
=
T
.
inv
(
local_mul_canonizer
.
merge_num_denum
(
node
.
inputs
[
1
:],
new_out
=
T
.
inv
(
local_mul_canonizer
.
merge_num_denum
(
node
.
inputs
[
1
:],
[]))
[]))
...
@@ -4501,7 +4488,8 @@ def local_pow_specialize_device(node):
...
@@ -4501,7 +4488,8 @@ def local_pow_specialize_device(node):
if
abs
(
y
)
>
2
:
if
abs
(
y
)
>
2
:
# We fuse all the pow together here to make
# We fuse all the pow together here to make
# compilation faster
# compilation faster
rval1
=
Elemwise
(
theano
.
scalar
.
Composite
(
rval1
=
Elemwise
(
theano
.
scalar
.
Composite
(
[
pow2_scal
[
0
]],
[
rval1_scal
]))
.
make_node
(
xsym
)
[
pow2_scal
[
0
]],
[
rval1_scal
]))
.
make_node
(
xsym
)
if
y
<
0
:
if
y
<
0
:
rval
=
[
T
.
inv
(
rval1
)]
rval
=
[
T
.
inv
(
rval1
)]
...
@@ -4566,8 +4554,8 @@ def local_mul_specialize(node):
...
@@ -4566,8 +4554,8 @@ def local_mul_specialize(node):
else
:
else
:
# The next case would cause a replace by an equivalent case.
# The next case would cause a replace by an equivalent case.
if
(
neg
and
if
(
neg
and
nb_neg_node
==
0
and
nb_neg_node
==
0
and
nb_cst
==
1
):
nb_cst
==
1
):
return
return
elif
neg
:
elif
neg
:
# Don't add an extra neg node as we can't
# Don't add an extra neg node as we can't
...
@@ -4640,8 +4628,8 @@ def check_for_x_over_absX(numerators, denominators):
...
@@ -4640,8 +4628,8 @@ def check_for_x_over_absX(numerators, denominators):
# TODO: this function should dig/search through dimshuffles
# TODO: this function should dig/search through dimshuffles
# This won't catch a dimshuffled absolute value
# This won't catch a dimshuffled absolute value
for
den
in
list
(
denominators
):
for
den
in
list
(
denominators
):
if
(
den
.
owner
and
den
.
owner
.
op
==
T
.
abs_
if
(
den
.
owner
and
den
.
owner
.
op
==
T
.
abs_
and
and
den
.
owner
.
inputs
[
0
]
in
numerators
):
den
.
owner
.
inputs
[
0
]
in
numerators
):
if
den
.
owner
.
inputs
[
0
]
.
type
.
dtype
.
startswith
(
'complex'
):
if
den
.
owner
.
inputs
[
0
]
.
type
.
dtype
.
startswith
(
'complex'
):
# TODO: Make an Op that projects a complex number to
# TODO: Make an Op that projects a complex number to
# have unit length but projects 0 to 0. That
# have unit length but projects 0 to 0. That
...
@@ -4715,8 +4703,8 @@ def local_log1p(node):
...
@@ -4715,8 +4703,8 @@ def local_log1p(node):
if
node
.
op
==
T
.
log
:
if
node
.
op
==
T
.
log
:
log_arg
,
=
node
.
inputs
log_arg
,
=
node
.
inputs
if
log_arg
.
owner
and
log_arg
.
owner
.
op
==
T
.
add
:
if
log_arg
.
owner
and
log_arg
.
owner
.
op
==
T
.
add
:
scalars
,
scalar_inputs
,
nonconsts
=
\
scalars
,
scalar_inputs
,
nonconsts
=
scalarconsts_rest
(
scalarconsts_rest
(
log_arg
.
owner
.
inputs
)
log_arg
.
owner
.
inputs
)
# scalar_inputs are potentially dimshuffled and fill'd scalars
# scalar_inputs are potentially dimshuffled and fill'd scalars
if
scalars
and
numpy
.
allclose
(
numpy
.
sum
(
scalars
),
1
):
if
scalars
and
numpy
.
allclose
(
numpy
.
sum
(
scalars
),
1
):
if
not
nonconsts
:
if
not
nonconsts
:
...
@@ -4748,7 +4736,7 @@ def local_log_add(node):
...
@@ -4748,7 +4736,7 @@ def local_log_add(node):
if
len
(
zi
)
!=
2
:
if
len
(
zi
)
!=
2
:
# -- upgrading Maximum to handle multiple inputs wasn't trivial
# -- upgrading Maximum to handle multiple inputs wasn't trivial
# TODO
# TODO
#raise NotImplementedError()
#
raise NotImplementedError()
return
return
pre_exp
=
[
x
.
owner
.
inputs
[
0
]
for
x
in
zi
pre_exp
=
[
x
.
owner
.
inputs
[
0
]
for
x
in
zi
if
x
.
owner
and
x
.
owner
.
op
==
T
.
exp
]
if
x
.
owner
and
x
.
owner
.
op
==
T
.
exp
]
...
@@ -4945,8 +4933,7 @@ def constant_folding(node):
...
@@ -4945,8 +4933,7 @@ def constant_folding(node):
storage_map
[
o
]
=
[
None
]
storage_map
[
o
]
=
[
None
]
compute_map
[
o
]
=
[
False
]
compute_map
[
o
]
=
[
False
]
if
(
hasattr
(
node
.
op
,
'python_constant_folding'
)
and
if
(
hasattr
(
node
.
op
,
'python_constant_folding'
)
and
node
.
op
.
python_constant_folding
(
node
)):
node
.
op
.
python_constant_folding
(
node
)):
old_value
=
getattr
(
node
.
op
,
'_op_use_c_code'
,
False
)
old_value
=
getattr
(
node
.
op
,
'_op_use_c_code'
,
False
)
try
:
try
:
node
.
op
.
_op_use_c_code
=
False
node
.
op
.
_op_use_c_code
=
False
...
@@ -5037,9 +5024,9 @@ register_specialize(local_one_minus_erf)
...
@@ -5037,9 +5024,9 @@ register_specialize(local_one_minus_erf)
local_one_minus_erf2
=
gof
.
PatternSub
((
T
.
add
,
local_one_minus_erf2
=
gof
.
PatternSub
((
T
.
add
,
1
,
1
,
(
T
.
mul
,
-
1
,
(
T
.
erf
,
'x'
))),
(
T
.
mul
,
-
1
,
(
T
.
erf
,
'x'
))),
(
T
.
erfc
,
'x'
),
(
T
.
erfc
,
'x'
),
allow_multiple_clients
=
True
,
allow_multiple_clients
=
True
,
name
=
'local_one_minus_erf2'
)
name
=
'local_one_minus_erf2'
)
register_canonicalize
(
local_one_minus_erf2
)
register_canonicalize
(
local_one_minus_erf2
)
register_stabilize
(
local_one_minus_erf2
)
register_stabilize
(
local_one_minus_erf2
)
register_specialize
(
local_one_minus_erf2
)
register_specialize
(
local_one_minus_erf2
)
...
@@ -5058,7 +5045,7 @@ register_canonicalize(local_one_plus_neg_erf)
...
@@ -5058,7 +5045,7 @@ register_canonicalize(local_one_plus_neg_erf)
register_stabilize
(
local_one_plus_neg_erf
)
register_stabilize
(
local_one_plus_neg_erf
)
register_specialize
(
local_one_plus_neg_erf
)
register_specialize
(
local_one_plus_neg_erf
)
#(-1)+erf(x) => -erfc(x) don't need erf(x)+(-1) as the canonicalize
#
(-1)+erf(x) => -erfc(x) don't need erf(x)+(-1) as the canonicalize
# will put the -1 as the first argument.
# will put the -1 as the first argument.
local_erf_minus_one
=
gof
.
PatternSub
((
T
.
add
,
local_erf_minus_one
=
gof
.
PatternSub
((
T
.
add
,
dict
(
pattern
=
'y'
,
constraint
=
_is_minus1
),
dict
(
pattern
=
'y'
,
constraint
=
_is_minus1
),
...
@@ -5124,7 +5111,7 @@ register_canonicalize(local_one_add_neg_erfc)
...
@@ -5124,7 +5111,7 @@ register_canonicalize(local_one_add_neg_erfc)
register_stabilize
(
local_one_add_neg_erfc
)
register_stabilize
(
local_one_add_neg_erfc
)
register_specialize
(
local_one_add_neg_erfc
)
register_specialize
(
local_one_add_neg_erfc
)
#(-1)+erfc(-x)=>erf(x)
#
(-1)+erfc(-x)=>erf(x)
local_erf_neg_minus_one
=
gof
.
PatternSub
((
T
.
add
,
local_erf_neg_minus_one
=
gof
.
PatternSub
((
T
.
add
,
dict
(
pattern
=
'y'
,
constraint
=
_is_minus1
),
dict
(
pattern
=
'y'
,
constraint
=
_is_minus1
),
(
T
.
erfc
,
(
T
.
neg
,
'x'
))),
(
T
.
erfc
,
(
T
.
neg
,
'x'
))),
...
@@ -5137,7 +5124,7 @@ register_canonicalize(local_erf_neg_minus_one)
...
@@ -5137,7 +5124,7 @@ register_canonicalize(local_erf_neg_minus_one)
register_stabilize
(
local_erf_neg_minus_one
)
register_stabilize
(
local_erf_neg_minus_one
)
register_specialize
(
local_erf_neg_minus_one
)
register_specialize
(
local_erf_neg_minus_one
)
#(-1)+erfc(-1*x)=>erf(x)
#
(-1)+erfc(-1*x)=>erf(x)
local_erf_neg_minus_one2
=
gof
.
PatternSub
((
T
.
add
,
local_erf_neg_minus_one2
=
gof
.
PatternSub
((
T
.
add
,
dict
(
pattern
=
'y'
,
constraint
=
_is_minus1
),
dict
(
pattern
=
'y'
,
constraint
=
_is_minus1
),
(
T
.
erfc
,
(
T
.
mul
,
-
1
,
'x'
))),
(
T
.
erfc
,
(
T
.
mul
,
-
1
,
'x'
))),
...
@@ -5176,8 +5163,8 @@ def local_log_erfc(node):
...
@@ -5176,8 +5163,8 @@ def local_log_erfc(node):
x
=
node
.
inputs
[
0
]
.
owner
.
inputs
[
0
]
x
=
node
.
inputs
[
0
]
.
owner
.
inputs
[
0
]
stab_value
=
(
-
x
**
2
-
T
.
log
(
x
)
-
.
5
*
T
.
log
(
numpy
.
pi
)
+
stab_value
=
(
-
x
**
2
-
T
.
log
(
x
)
-
.
5
*
T
.
log
(
numpy
.
pi
)
+
T
.
log
(
1
-
1
/
(
2
*
x
**
2
)
+
3
/
(
4
*
x
**
4
)
T
.
log
(
1
-
1
/
(
2
*
x
**
2
)
+
3
/
(
4
*
x
**
4
)
-
-
15
/
(
8
*
x
**
6
)))
15
/
(
8
*
x
**
6
)))
if
(
node
.
outputs
[
0
]
.
dtype
==
'float32'
or
if
(
node
.
outputs
[
0
]
.
dtype
==
'float32'
or
node
.
outputs
[
0
]
.
dtype
==
'float16'
):
node
.
outputs
[
0
]
.
dtype
==
'float16'
):
...
@@ -5191,8 +5178,8 @@ def local_log_erfc(node):
...
@@ -5191,8 +5178,8 @@ def local_log_erfc(node):
# Stability optimization of the grad of log(erfc(x))
# Stability optimization of the grad of log(erfc(x))
#([y*]exp(-(x**2)))/erfc(x) # The y* is optional
#
([y*]exp(-(x**2)))/erfc(x) # The y* is optional
#([y*]exp(x**2))/erfc(-x) => [y*](when x>threashold,
#
([y*]exp(x**2))/erfc(-x) => [y*](when x>threashold,
# sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)))
# sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)))
# for float64: threshold=26.63 see at the end of the fct for the explaination
# for float64: threshold=26.63 see at the end of the fct for the explaination
# for float32: threshold=9.3 see at the end of the fct for the explaination
# for float32: threshold=9.3 see at the end of the fct for the explaination
...
@@ -5226,8 +5213,8 @@ def local_grad_log_erfc_neg(node):
...
@@ -5226,8 +5213,8 @@ def local_grad_log_erfc_neg(node):
if
mul
.
owner
.
inputs
[
0
]
.
owner
or
len
(
mul
.
owner
.
inputs
)
!=
2
:
if
mul
.
owner
.
inputs
[
0
]
.
owner
or
len
(
mul
.
owner
.
inputs
)
!=
2
:
return
False
return
False
y
=
mul
.
owner
.
inputs
[
0
]
y
=
mul
.
owner
.
inputs
[
0
]
if
(
not
mul
.
owner
.
inputs
[
1
]
.
owner
if
(
not
mul
.
owner
.
inputs
[
1
]
.
owner
or
or
mul
.
owner
.
inputs
[
1
]
.
owner
.
op
!=
T
.
exp
):
mul
.
owner
.
inputs
[
1
]
.
owner
.
op
!=
T
.
exp
):
return
False
return
False
exp
=
mul
.
owner
.
inputs
[
1
]
exp
=
mul
.
owner
.
inputs
[
1
]
...
@@ -5236,8 +5223,8 @@ def local_grad_log_erfc_neg(node):
...
@@ -5236,8 +5223,8 @@ def local_grad_log_erfc_neg(node):
if
exp
.
owner
.
inputs
[
0
]
.
owner
.
op
==
T
.
neg
:
if
exp
.
owner
.
inputs
[
0
]
.
owner
.
op
==
T
.
neg
:
neg
=
exp
.
owner
.
inputs
[
0
]
neg
=
exp
.
owner
.
inputs
[
0
]
if
(
not
neg
.
owner
.
inputs
[
0
]
.
owner
if
(
not
neg
.
owner
.
inputs
[
0
]
.
owner
or
or
neg
.
owner
.
inputs
[
0
]
.
owner
.
op
!=
T
.
sqr
):
neg
.
owner
.
inputs
[
0
]
.
owner
.
op
!=
T
.
sqr
):
return
False
return
False
sqr
=
neg
.
owner
.
inputs
[
0
]
sqr
=
neg
.
owner
.
inputs
[
0
]
x
=
sqr
.
owner
.
inputs
[
0
]
x
=
sqr
.
owner
.
inputs
[
0
]
...
@@ -5279,8 +5266,8 @@ def local_grad_log_erfc_neg(node):
...
@@ -5279,8 +5266,8 @@ def local_grad_log_erfc_neg(node):
return
False
return
False
if
len
(
mul_neg
.
owner
.
inputs
)
==
2
:
if
len
(
mul_neg
.
owner
.
inputs
)
==
2
:
if
(
not
mul_neg
.
owner
.
inputs
[
1
]
.
owner
if
(
not
mul_neg
.
owner
.
inputs
[
1
]
.
owner
or
or
mul_neg
.
owner
.
inputs
[
1
]
.
owner
.
op
!=
T
.
sqr
):
mul_neg
.
owner
.
inputs
[
1
]
.
owner
.
op
!=
T
.
sqr
):
return
False
return
False
sqr
=
mul_neg
.
owner
.
inputs
[
1
]
sqr
=
mul_neg
.
owner
.
inputs
[
1
]
x
=
sqr
.
owner
.
inputs
[
0
]
x
=
sqr
.
owner
.
inputs
[
0
]
...
@@ -5292,8 +5279,8 @@ def local_grad_log_erfc_neg(node):
...
@@ -5292,8 +5279,8 @@ def local_grad_log_erfc_neg(node):
return
False
return
False
if
cst2
!=
-
1
:
if
cst2
!=
-
1
:
if
(
not
erfc_x
.
owner
or
erfc_x
.
owner
.
op
!=
T
.
mul
if
(
not
erfc_x
.
owner
or
erfc_x
.
owner
.
op
!=
T
.
mul
or
or
len
(
erfc_x
.
owner
.
inputs
)
!=
2
):
len
(
erfc_x
.
owner
.
inputs
)
!=
2
):
# todo implement that case
# todo implement that case
return
False
return
False
if
erfc_x
.
owner
.
inputs
[
1
]
is
not
mul_neg
.
owner
.
inputs
[
1
]:
if
erfc_x
.
owner
.
inputs
[
1
]
is
not
mul_neg
.
owner
.
inputs
[
1
]:
...
@@ -5324,12 +5311,12 @@ def local_grad_log_erfc_neg(node):
...
@@ -5324,12 +5311,12 @@ def local_grad_log_erfc_neg(node):
# aaron value
# aaron value
stab_value
=
(
x
*
T
.
pow
(
1
-
1
/
(
2
*
(
x
**
2
))
+
stab_value
=
(
x
*
T
.
pow
(
1
-
1
/
(
2
*
(
x
**
2
))
+
3
/
(
4
*
(
x
**
4
))
-
15
/
(
8
*
(
x
**
6
)),
-
1
)
3
/
(
4
*
(
x
**
4
))
-
15
/
(
8
*
(
x
**
6
)),
-
1
)
*
*
T
.
cast
(
T
.
sqrt
(
numpy
.
pi
),
dtype
=
x
.
dtype
))
T
.
cast
(
T
.
sqrt
(
numpy
.
pi
),
dtype
=
x
.
dtype
))
if
x
.
dtype
==
'float32'
or
x
.
dtype
==
'float16'
:
if
x
.
dtype
==
'float32'
or
x
.
dtype
==
'float16'
:
threshold
=
9.3
threshold
=
9.3
#threshold = 10.1
#
threshold = 10.1
elif
x
.
dtype
==
'float64'
:
elif
x
.
dtype
==
'float64'
:
threshold
=
26.641747557
threshold
=
26.641747557
ret
=
T
.
switch
(
x
<
threshold
,
true_div_no_mul
,
stab_value
)
*
y
ret
=
T
.
switch
(
x
<
threshold
,
true_div_no_mul
,
stab_value
)
*
y
...
@@ -5531,6 +5518,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
...
@@ -5531,6 +5518,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
if
maker
is
None
:
if
maker
is
None
:
def
maker
(
node
,
scalar_op
):
def
maker
(
node
,
scalar_op
):
return
OP
(
scalar_op
)
return
OP
(
scalar_op
)
def
local_fuse
(
node
):
def
local_fuse
(
node
):
"""
"""
As part of specialization, we fuse two consecutive elemwise Ops of the
As part of specialization, we fuse two consecutive elemwise Ops of the
...
@@ -5598,13 +5586,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
...
@@ -5598,13 +5586,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
# If a variable is used as multiple into to the same node,
# If a variable is used as multiple into to the same node,
# we still want to fusion. So we take the set.
# we still want to fusion. So we take the set.
if
(
i
.
owner
and
if
(
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
OP
)
and
isinstance
(
i
.
owner
.
op
,
OP
)
and
len
(
set
([
n
for
n
,
idx
in
i
.
clients
]))
==
1
and
len
(
set
([
n
for
n
,
idx
in
i
.
clients
]))
==
1
and
# Do not merge elemwise that don't have the same
# Do not merge elemwise that don't have the same
# broadcastable pattern to don't redo duplicate
# broadcastable pattern to don't redo duplicate
# computation due to broadcast.
# computation due to broadcast.
i
.
owner
.
outputs
[
0
]
.
broadcastable
==
node
.
outputs
[
0
]
.
broadcastable
):
i
.
owner
.
outputs
[
0
]
.
broadcastable
==
node
.
outputs
[
0
]
.
broadcastable
):
do_fusion
=
True
do_fusion
=
True
try
:
try
:
tmp_s_input
=
[]
tmp_s_input
=
[]
...
@@ -5840,14 +5828,14 @@ def local_add_mul_fusion(node):
...
@@ -5840,14 +5828,14 @@ def local_add_mul_fusion(node):
"""
"""
if
(
not
isinstance
(
node
.
op
,
Elemwise
)
or
if
(
not
isinstance
(
node
.
op
,
Elemwise
)
or
not
isinstance
(
node
.
op
.
scalar_op
,
(
scalar
.
Add
,
scalar
.
Mul
))):
not
isinstance
(
node
.
op
.
scalar_op
,
(
scalar
.
Add
,
scalar
.
Mul
))):
return
False
return
False
s_op
=
node
.
op
.
scalar_op
.
__class__
s_op
=
node
.
op
.
scalar_op
.
__class__
for
inp
in
node
.
inputs
:
for
inp
in
node
.
inputs
:
if
(
inp
.
owner
and
if
(
inp
.
owner
and
isinstance
(
inp
.
owner
.
op
,
Elemwise
)
and
isinstance
(
inp
.
owner
.
op
,
Elemwise
)
and
isinstance
(
inp
.
owner
.
op
.
scalar_op
,
s_op
)):
isinstance
(
inp
.
owner
.
op
.
scalar_op
,
s_op
)):
l
=
list
(
node
.
inputs
)
l
=
list
(
node
.
inputs
)
l
.
remove
(
inp
)
l
.
remove
(
inp
)
return
[
node
.
op
(
*
(
l
+
inp
.
owner
.
inputs
))]
return
[
node
.
op
(
*
(
l
+
inp
.
owner
.
inputs
))]
...
@@ -5882,13 +5870,15 @@ else:
...
@@ -5882,13 +5870,15 @@ else:
# just returns the input, it should be removed from the graph to
# just returns the input, it should be removed from the graph to
# make sure all possible optimizations can be applied.
# make sure all possible optimizations can be applied.
register_canonicalize
(
gof
.
OpRemove
(
theano
.
gradient
.
consider_constant_
),
register_canonicalize
(
gof
.
OpRemove
(
theano
.
gradient
.
consider_constant_
),
'fast_compile'
,
'fast_run'
,
name
=
'remove_consider_constant'
)
'fast_compile'
,
'fast_run'
,
name
=
'remove_consider_constant'
)
register_canonicalize
(
gof
.
OpRemove
(
theano
.
gradient
.
zero_grad_
),
register_canonicalize
(
gof
.
OpRemove
(
theano
.
gradient
.
zero_grad_
),
'fast_compile'
,
'fast_run'
,
name
=
'remove_zero_grad'
)
'fast_compile'
,
'fast_run'
,
name
=
'remove_zero_grad'
)
register_canonicalize
(
gof
.
OpRemove
(
theano
.
gradient
.
disconnected_grad_
),
register_canonicalize
(
gof
.
OpRemove
(
theano
.
gradient
.
disconnected_grad_
),
'fast_compile'
,
'fast_run'
,
name
=
'remove_disconnected_grad'
)
'fast_compile'
,
'fast_run'
,
name
=
'remove_disconnected_grad'
)
@register_canonicalize
@register_canonicalize
...
...
theano/tests/test_flake8.py
浏览文件 @
d40861ec
...
@@ -63,7 +63,6 @@ whitelist_flake8 = [
...
@@ -63,7 +63,6 @@ whitelist_flake8 = [
"tensor/sort.py"
,
"tensor/sort.py"
,
"tensor/__init__.py"
,
"tensor/__init__.py"
,
"tensor/opt_uncanonicalize.py"
,
"tensor/opt_uncanonicalize.py"
,
"tensor/opt.py"
,
"tensor/blas.py"
,
"tensor/blas.py"
,
"tensor/extra_ops.py"
,
"tensor/extra_ops.py"
,
"tensor/nlinalg.py"
,
"tensor/nlinalg.py"
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论