testgroup / pytensor · Commits

Commit f6141a60, authored Aug 06, 2015 by Iban Harlouchet

    numpydoc for theano/tensor/opt.py

Parent: 7f312182

Showing 1 changed file with 254 additions and 134 deletions:
theano/tensor/opt.py (+254 −134)
"""
"""
Tensor optimizations addressing the ops in basic.py
Tensor optimizations addressing the ops in basic.py
.
"""
"""
from
__future__
import
print_function
from
__future__
import
print_function
# TODO: intelligent merge for mul/add
# TODO: intelligent merge for mul/add
...
@@ -68,15 +68,20 @@ def copy_stack_trace(from_var, to_var):
     Copies the stack trace from one or more tensor variables to
     one or more tensor variables.

-    :param from_var: tensor variable or list of tensor variables to
-        copy stack traces from.
-    :param to_var: tensor variable or list of tensor variables to
-        copy stack traces to.
+    Parameters
+    ----------
+    from_var
+        Tensor variable or list of tensor variables to copy stack traces from.
+    to_var
+        Tensor variable or list of tensor variables to copy stack traces to.

-    .. note:: The stacktrace is assumed to be of the form of a list of lists
+    Notes
+    -----
+    The stacktrace is assumed to be of the form of a list of lists
     of tuples. Each tuple contains the filename, line number, function name
     and so on. Each list of tuples contains the tuples belonging to a
     particular variable.
     """
     # Store stack traces from from_var
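The behavior this hunk documents can be sketched in plain Python. The `tag.trace` attribute name follows Theano's convention, but the container objects below are hypothetical stand-ins for tensor variables:

```python
from types import SimpleNamespace

def copy_stack_trace(from_var, to_var):
    # Sketch: collect traces from all source variables, then assign the
    # combined list to every destination variable.
    if not isinstance(from_var, (list, tuple)):
        from_var = [from_var]
    traces = []
    for v in from_var:
        traces.extend(getattr(v.tag, "trace", []))
    if not isinstance(to_var, (list, tuple)):
        to_var = [to_var]
    for v in to_var:
        v.tag.trace = traces

# toy variables standing in for tensor variables
a = SimpleNamespace(tag=SimpleNamespace(trace=[("f.py", 10, "f")]))
b = SimpleNamespace(tag=SimpleNamespace())
copy_stack_trace(a, b)
```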
...
@@ -151,11 +156,18 @@ def _fill_chain(new_out, orig_inputs):
 def encompasses_broadcastable(b1, b2):
     """
-    Returns True if the broadcastable patterns b1 and b2 are such that b2 is
-    broadcasted to b1's shape and not the opposite.
+    Parameters
+    ----------
+    b1
+        The broadcastable attribute of a tensor type.
+    b2
+        The broadcastable attribute of a tensor type.

-    :param b1: the broadcastable attribute of a tensor type
-    :param b2: the broadcastable attribute of a tensor type
+    Returns
+    -------
+    True if the broadcastable patterns b1 and b2 are such that b2 is
+    broadcasted to b1's shape and not the opposite.
     """
     if len(b1) < len(b2):
         return False
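The diff cuts off after the length check; a plausible completion of the predicate (every trailing dimension that is broadcastable in `b1` must also be broadcastable in `b2`) would be:

```python
def encompasses_broadcastable(b1, b2):
    # True iff b2 can be broadcast to b1's shape, not the opposite
    if len(b1) < len(b2):
        return False
    b1 = b1[-len(b2):]  # align trailing dimensions
    return not any(v1 and not v2 for v1, v2 in zip(b1, b2))
```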
...
@@ -184,7 +196,8 @@ def scalarconsts_rest(inputs):
 def broadcast_like(value, template, fgraph, dtype=None):
-    """Return a Variable with the same shape and dtype as the template,
+    """
+    Return a Variable with the same shape and dtype as the template,
     filled by broadcasting value through it. `value` will be cast as
     necessary.
...
@@ -240,9 +253,11 @@ def inplace_elemwise_optimizer_op(OP):
     see if it can operate inplace on that input. If so, makes the
     change and goes to the next output or Broadcast Op.

-    Examples:
-    x + y + z -> x += y += z
-    (x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
+    Examples
+    --------
+    x + y + z -> x += y += z
+    (x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
     """
     # We should not validate too often as this takes too much time to
     # execute!
...
@@ -507,6 +522,7 @@ def local_dimshuffle_lift(node):
     After this transform, clusters of Elemwise operations are
     void of DimShuffle operations.
+
     """
     op = node.op
     if not isinstance(op, DimShuffle):
...
@@ -556,6 +572,7 @@ def local_lift_transpose_through_dot(node):
     The transformation should be applied whether or not the transpose is
     inplace. The newly-introduced transpositions are not inplace; this will
     be taken care of in a later optimization phase.
+
     """
     if not (isinstance(node.op, T.DimShuffle) and
             node.op.new_order == (1, 0)):
         return False
...
@@ -639,11 +656,12 @@ def local_scalar_tensor_scalar(node):
 class MakeVector(T.Op):
-    """Concatenate a number of scalars together into a vector
+    """Concatenate a number of scalars together into a vector.

     This is a simple version of stack() that introduces far less cruft
     into the graph. Should work with 0 inputs. The constant_folding
     optimization will remove it.
     """
     __props__ = ("dtype",)
...
@@ -755,7 +773,7 @@ T.pprint.assign(lambda pstate, r: r.owner and
 class ShapeFeature(object):
-    """Graph optimizer for removing all calls to shape()
+    """Graph optimizer for removing all calls to shape().

     This optimizer replaces all Shapes and Subtensors of Shapes with
     Shape_i and MakeVector Ops.
...
@@ -791,7 +809,6 @@ class ShapeFeature(object):
     For example the infer_shape for a matrix-matrix product would accept
     input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),).

     Inferring the shape of internal nodes in the graph is important
     for doing size-driven optimizations. If we know how big various
     intermediate results will be, we can estimate the cost of many Ops
...
@@ -800,18 +817,18 @@ class ShapeFeature(object):
     In cases where you cannot figure out the shape, raise a ShapeError.

-    .. note::
+    Notes
+    -----
     Right now there is only the ConvOp that could really take
     advantage of this shape inference, but it is worth it even
     just for the ConvOp. All that's necessary to do shape
     inference is 1) to mark shared inputs as having a particular
     shape, either via a .tag or some similar hacking; and 2) to
     add an optional Param() argument to promise that inputs will
     have a certain shape (or even to have certain shapes in
     certain dimensions). We can't automatically infer the shape of
-    shared variables as they can change of shape during the
+    shared variables as they can change shape during the
     execution by default. (NOT IMPLEMENTED YET, BUT IS IN TRAC)

     Using Shape information in Optimizations
...
@@ -842,7 +859,7 @@ class ShapeFeature(object):
     """
     def shape_ir(self, i, r):
-        """Return symbolic r.shape[i] for tensor variable r, int i"""
+        """Return symbolic r.shape[i] for tensor variable r, int i."""
         if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]:
             return self.lscalar_one
         else:
...
@@ -855,7 +872,7 @@ class ShapeFeature(object):
         return s

     def shape_tuple(self, r):
-        """Return a tuple of symbolic shape vars for tensor variable r"""
+        """Return a tuple of symbolic shape vars for tensor variable r."""
         if not hasattr(r, 'ndim'):
             # This happens for NoneConst.
             return None
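The per-dimension logic behind `shape_tuple` can be sketched without the symbolic machinery: variables without `ndim` (such as NoneConst) have no shape to report, broadcastable dimensions are known to be length 1, and the rest get a symbolic entry. The placeholder tuple below stands in for the real `shape_ir` result and is purely illustrative:

```python
from types import SimpleNamespace

def shape_tuple(r):
    # NoneConst and similar variables have no ndim: nothing to report
    if not hasattr(r, "ndim"):
        return None
    # broadcastable dims are length 1; others get a symbolic placeholder
    return tuple(1 if r.broadcastable[i] else ("shape_of", i)
                 for i in range(r.ndim))

r = SimpleNamespace(ndim=2, broadcastable=(True, False))
```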
...
@@ -867,6 +884,7 @@ class ShapeFeature(object):
         This function is used for Ops that don't implement infer_shape.
         Ops that do implement infer_shape should use the i_shapes parameter,
         but this default implementation ignores it.
+
         """
         rval = []
         for r in node.outputs:
...
@@ -880,6 +898,7 @@ class ShapeFeature(object):
         """Return a symbolic integer scalar for the shape element s_i.

         The s_i argument was produced by the infer_shape() of an Op subclass.
+
         """
         # unpack the s_i that the Op returned
         assert s_i is not None
...
@@ -933,8 +952,11 @@ class ShapeFeature(object):
     def set_shape(self, r, s):
         """Assign the shape `s` to previously un-shaped variable `r`.

-        :type r: a variable
-        :type s: None or a tuple of symbolic integers
+        Parameters
+        ----------
+        r : a variable
+        s : None or a tuple of symbolic integers
         """
         assert r not in self.shape_of, 'r already in shape_of'
         if s is None:
...
@@ -972,11 +994,12 @@ class ShapeFeature(object):
         self.shape_of_reverse_index.setdefault(sv, set()).add(r)

     def update_shape(self, r, other_r):
-        '''
-        Replace shape of r by shape of other_r.
+        """
+        Replace shape of r by shape of other_r.

         If, on some dimensions, the shape of other_r is not informative,
         keep the shape of r on those dimensions.
-        '''
+        """
         # other_r should already have a shape
         assert other_r in self.shape_of, ('other_r not in shape_of', other_r)
         other_shape = self.shape_of[other_r]
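The merge rule documented above ("keep `r`'s entry where `other_r`'s is not informative") can be illustrated on plain tuples, using `None` for an uninformative entry; the real code operates on symbolic shape variables, so this is only a sketch:

```python
def merge_shapes(r_shape, other_shape):
    # prefer other_shape, but fall back to r_shape where other is unknown
    return tuple(o if o is not None else r
                 for r, o in zip(r_shape, other_shape))
```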
...
@@ -1303,8 +1326,7 @@ class ShapeFeature(object):
 class ShapeOptimizer(Optimizer):
-    """Optimizer that serves to add ShapeFeature as an fgraph feature.
-    """
+    """Optimizer that serves to add ShapeFeature as an fgraph feature."""
     def __init__(self):
         Optimizer.__init__(self)
...
@@ -1392,6 +1414,7 @@ def local_useless_alloc(node):
     If the input type is the same as the output type (dtype and broadcast)
     there is no change in the shape of the input. So this is just a simple copy
     of the input. This is not needed.
+
     """
     if node.op == T.alloc:
         if node.inputs[0].type == node.outputs[0].type:
...
@@ -1438,14 +1461,15 @@ def local_track_shape_i(node):
 @gof.local_optimizer([Subtensor, AdvancedSubtensor1])
 def local_subtensor_make_vector(node):
     """
-    replace all subtensor(make_vector) like:
+    Replace all subtensor(make_vector) like:
     [a,b,c][0] -> a
     [a,b,c][0:2] -> [a,b]

-    replace all AdvancedSubtensor1(make_vector) like:
+    Replace all AdvancedSubtensor1(make_vector) like:
     [a,b,c][[0,2]] -> [a,c]

-    we can do this for constant indexes
+    We can do this for constant indexes.
     """
     x = node.inputs[0]
     if not x.owner or x.owner.op != make_vector:
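The three rewrite patterns in the docstring behave exactly like ordinary Python indexing on a list of graph elements; a toy version for constant indices:

```python
def subtensor_make_vector(elements, index):
    # [a,b,c][0] -> a ; [a,b,c][0:2] -> [a,b] ; [a,b,c][[0,2]] -> [a,c]
    if isinstance(index, (int, slice)):
        return elements[index]
    # AdvancedSubtensor1-style: a list/vector of constant indices
    return [elements[i] for i in index]
```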
...
@@ -1514,7 +1538,6 @@ def local_subtensor_make_vector(node):
 @gof.local_optimizer([T.Elemwise])
 def local_useless_elemwise(node):
     """
     eq(x,x) -> 1
     neq(x,x) -> 0
     mul(x) -> x
...
@@ -1559,8 +1582,7 @@ def local_useless_elemwise(node):
 @register_specialize
 @gof.local_optimizer([T.Elemwise])
 def local_alloc_unary(node):
-    """unary(alloc(x, shp)) -> alloc(unary(x), shp)
-    """
+    """unary(alloc(x, shp)) -> alloc(unary(x), shp)"""
     if isinstance(node.op, T.Elemwise) and len(node.inputs) == 1:
         a = node.inputs[0]
         if a.owner and isinstance(a.owner.op, T.Alloc):
...
@@ -1587,6 +1609,7 @@ def local_cast_cast(node):
     dtype1 == dtype2

     TODO: the base dtype is the same (int, uint, float, complex)
     and the first cast causes an upcast.
+
     """
     if (not isinstance(node.op, T.Elemwise) or
             not isinstance(node.op.scalar_op, scalar.Cast)):
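The `cast(cast(x, dtype1), dtype2)` collapse for `dtype1 == dtype2` can be checked numerically with NumPy:

```python
import numpy as np

x = np.array([1.5, 2.5], dtype="float64")
once = x.astype("float32")                     # single cast
twice = x.astype("float32").astype("float32")  # redundant second cast
# the second cast is a no-op, so the optimization can drop it
assert (once == twice).all() and twice.dtype == np.float32
```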
...
@@ -1607,9 +1630,9 @@ def local_cast_cast(node):
 def local_func_inv(node):
     """
     Check for two consecutive operations that are functional inverses
-    and remove them from the function graph
+    and remove them from the function graph.
     """
     inv_pairs = (
         (basic.Deg2Rad, basic.Rad2Deg),
         (basic.Cosh, basic.ArcCosh),
...
@@ -1641,9 +1664,9 @@ def local_func_inv(node):
 def is_inverse_pair(node_op, prev_op, inv_pair):
     """
     Given two consecutive operations, check if they are the
-    provided pair of inverse functions
+    provided pair of inverse functions.
     """
     node_is_op0 = isinstance(node_op, inv_pair[0])
     node_is_op1 = isinstance(node_op, inv_pair[1])
     prev_is_op0 = isinstance(prev_op, inv_pair[0])
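The hunk cuts off after the first `isinstance` checks; a plausible completion combines them so that the pair may appear in either order. The toy op classes below stand in for `basic.Deg2Rad` / `basic.Rad2Deg`:

```python
def is_inverse_pair(node_op, prev_op, inv_pair):
    # the two consecutive ops must be the two members of inv_pair,
    # in either order
    node_is_op0 = isinstance(node_op, inv_pair[0])
    node_is_op1 = isinstance(node_op, inv_pair[1])
    prev_is_op0 = isinstance(prev_op, inv_pair[0])
    prev_is_op1 = isinstance(prev_op, inv_pair[1])
    return (node_is_op0 and prev_is_op1) or (node_is_op1 and prev_is_op0)

class Deg2Rad:  # hypothetical stand-in
    pass

class Rad2Deg:  # hypothetical stand-in
    pass
```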
...
@@ -1659,20 +1682,24 @@ class Assert(T.Op):
     Returns the first parameter if the condition is true, otherwise, triggers
     AssertionError.

-    Example:
-    T = theano.tensor
-    x = T.vector('x')
-    assert_op = T.opt.Assert()
-    func = theano.function([x], assert_op(x, x.size<2))
-
-    Notes:
+    Notes
+    -----
     This Op is a debugging feature. It can be removed from the graph
     because of optimizations, and can hide some possible optimizations to
     the optimizer. Specifically, removing happens if it can be determined
     that condition will always be true. Also, the output of the Op must be
     used in the function computing the graph, but it doesn't have to be
     returned.

+    Examples
+    --------
+    T = theano.tensor
+    x = T.vector('x')
+    assert_op = T.opt.Assert()
+    func = theano.function([x], assert_op(x, x.size<2))
+
     """
     __props__ = ('msg',)
     view_map = {0: [0]}
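Evaluated eagerly, the Op's contract is simply "return the first argument if every condition holds, otherwise raise AssertionError"; a plain-Python sketch:

```python
def assert_op(value, *conditions):
    # return value unchanged if all conditions hold, else raise
    if not all(conditions):
        raise AssertionError("Assert condition failed")
    return value
```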
...
@@ -1770,7 +1797,9 @@ def local_remove_all_assert(node):
     """An optimization disabled by default that removes all asserts from
     the graph.

-    :note: See the :ref:`unsafe` section to know how to enable it.
+    Notes
+    -----
+    See the :ref:`unsafe` section to know how to enable it.
     """
     if not isinstance(node.op, Assert):
...
@@ -1804,11 +1833,12 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
     BROADCAST CONDITION: the condition is that the one input that is
     not to be optimized must have the same broadcast pattern as the
-    output
+    output.

     We can change the alloc by a dimshuffle, as the elemwise
     already has the shape info. The dimshuffle will be faster
-    to exec
+    to exec.
     """
     if not isinstance(node.op, ElemwiseOP):
         return False
...
@@ -1969,6 +1999,7 @@ def local_upcast_elemwise_constant_inputs(node):
     those Ops do implicit upcasting anyway.

     Rationale: it helps merge things like (1-x) and (1.0 - x).
+
     """
     if len(node.outputs) > 1:
         return
...
@@ -2033,7 +2064,8 @@ def local_upcast_elemwise_constant_inputs(node):
 @register_specialize
 @gof.local_optimizer([IncSubtensor])
 def local_useless_inc_subtensor(node):
-    """Remove IncSubtensor, when we overwrite the full inputs with the
+    """
+    Remove IncSubtensor, when we overwrite the full inputs with the
     new value.
     """
...
@@ -2082,6 +2114,7 @@ def local_set_to_inc_subtensor(node):
     """
     AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) ->
     AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False)
+
     """
     if (isinstance(node.op, AdvancedIncSubtensor1) and
             node.op.set_instead_of_inc and
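The rewrite rests on the identity `x[ilist] = x[ilist] + other` ≡ `x[ilist] += other`, which NumPy can demonstrate directly (for duplicate-free index lists):

```python
import numpy as np

x = np.arange(5.0)
ilist = [1, 3]
other = np.array([10.0, 20.0])

a = x.copy()
a[ilist] = a[ilist] + other  # "set" form (set_instead_of_inc=True)
b = x.copy()
b[ilist] += other            # "inc" form (set_instead_of_inc=False)
assert (a == b).all()
```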
...
@@ -2144,6 +2177,7 @@ def local_useless_subtensor(node):
     AdvancedSubtensor1 case, the full input is taken when the indices are
     equivalent to `arange(0, input.shape[0], 1)` using either an explicit
     list/vector or the ARange op.
+
     """
     # This optimization needs ShapeOpt and fgraph.shape_feature
     if not hasattr(node.fgraph, 'shape_feature'):
...
@@ -2261,6 +2295,7 @@ def local_subtensor_lift(node):
     elemwise(x,...)[idx] -> elemwise(x[idx],...)
     when x,... are broadcasted scalar or not broadcasted at all

     rebroadcast(x)[idx] => rebroadcast(x[idx])
+
     """
     if isinstance(node.op, Subtensor):
         u = node.inputs[0]
...
@@ -2327,7 +2362,7 @@ def local_subtensor_lift(node):
 def merge_two_slices(slice1, len1, slice2, len2):
-    '''
+    """
     This function merges two slices into a single slice. The code works on
     the assumption that:

     a) slice1 is actually a slice and not an index, while slice2
 ...
@@ -2340,7 +2375,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
     the two consecutive slices.

     ``len1`` is the length of the tensor **before** applying the first slice,
     while ``len2`` is the length **after** applying the first slice.
-    '''
+    """
     list_opt = [local_abs_merge, local_mul_switch_sink,
                 local_upcast_elemwise_constant_inputs,
                 local_remove_switch_const_cond, constant_folding]
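For concrete (non-symbolic) lengths, merging two consecutive slices is exactly slice composition, which Python's `range` can compute. This toy version ignores the symbolic machinery, assumes non-negative steps, and derives `len2` implicitly from `slice1` and `len1`:

```python
def merge_slices(slice1, len1, slice2):
    # compose slice2-after-slice1 into one slice over the original length
    # (range objects compose slices for us; valid for non-negative steps)
    r = range(len1)[slice1][slice2]
    return slice(r.start, r.stop, r.step)

x = list(range(10))
merged = merge_slices(slice(1, 9, 2), len(x), slice(1, 3))
assert x[merged] == x[1:9:2][1:3]
```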
...
@@ -2466,6 +2501,7 @@ def local_subtensor_merge(node):
     Refactored optimization to deal with all cases of tensor merging.
     Given a subgraph of the form Subtensor(Subtensor(u)), the optimization
     expresses all slices in a canonical form, and then merges them together.
+
     """
     if isinstance(node.op, Subtensor):
...
@@ -2601,7 +2637,8 @@ def local_subtensor_of_dot(node):
     idxs_a is the first A.ndim-1 entries of idxs,
     and idxs_b is the remaining entries of idxs (if any),
     modified to skip the second-to-last dimension of B
-    (because dot sums over this dimension)
+    (because dot sums over this dimension).
     """
     if not isinstance(node.op, Subtensor):
         return
...
@@ -2715,7 +2752,8 @@ compile.optdb.register('pre_local_IncSubtensor_serialize',
 @gof.local_optimizer([IncSubtensor], inplace=True)
 def local_inplace_setsubtensor(node):
     """
-    Also work for GpuIncSubtensor
+    Also work for GpuIncSubtensor.
     """
     if isinstance(node.op, IncSubtensor) and not node.op.inplace:
         new_op = node.op.__class__(
...
@@ -2734,7 +2772,10 @@ compile.optdb.register('local_inplace_setsubtensor',
 @gof.local_optimizer([AdvancedIncSubtensor1], inplace=True)
 def local_inplace_incsubtensor1(node):
-    """ also work for GpuAdvancedIncSubtensor1 """
+    """
+    Also work for GpuAdvancedIncSubtensor1.
+    """
     if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
         new_op = node.op.clone_inplace()
         new_node = new_op(*node.inputs)
...
@@ -2756,6 +2797,7 @@ compile.optdb.register('local_inplace_incsubtensor1',
 def local_incsubtensor_of_zeros(node):
     """
     IncSubtensor(x, zeros, idx) -> x
+
     """
     if (isinstance(node.op, (IncSubtensor,
                              AdvancedIncSubtensor,
...
@@ -2784,6 +2826,7 @@ def local_setsubtensor_of_constants(node):
     SetSubtensor(x, x[idx], idx) -> x

     when x is constant or alloc.
+
     """
     if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc:
         x = node.inputs[0]
...
@@ -2813,14 +2856,16 @@ def local_setsubtensor_of_constants(node):
 @register_stabilize
 @gof.local_optimizer([AdvancedSubtensor1])
 def local_adv_sub1_adv_inc_sub1(node):
-    """Optimize the possible AdvSub1(AdvIncSub1(...), ...)
+    """Optimize the possible AdvSub1(AdvIncSub1(...), ...).

     AdvancedSubtensor1(AdvancedIncSubtensor1(0s, y, idx), idx) -> y
     AdvancedSubtensor1(AdvancedSetSubtensor1(x, y, idx), idx) -> y

-    :note: This opt add AssertOp. Otherwise, it would remove shape and
-        index error. If you want to get rid of them, see the
-        :ref:`unsafe_optimization` section.
+    Notes
+    -----
+    This opt adds an AssertOp. Otherwise, it would remove shape and
+    index errors. If you want to get rid of them, see the
+    :ref:`unsafe_optimization` section.
     """
     if not isinstance(node.op, AdvancedSubtensor1):
...
@@ -2862,6 +2907,7 @@ def local_useless_inc_subtensor_alloc(node):
     Replaces an [Advanced]IncSubtensor[1], whose increment is an `alloc` of
     a fully or partially broadcastable variable, by one that skips the
     intermediate `alloc` where possible.
+
     """
     if isinstance(node.op, (IncSubtensor,
                             AdvancedIncSubtensor,
...
@@ -2962,7 +3008,8 @@ def local_useless_inc_subtensor_alloc(node):
 @gof.local_optimizer([T.Rebroadcast])
 def local_useless_rebroadcast(node):
     """
-    Remove Rebroadcast if it does not actually change the broadcasting pattern
+    Remove Rebroadcast if it does not actually change the broadcasting pattern.
     """
     if isinstance(node.op, T.Rebroadcast):
         x = node.inputs[0]
...
@@ -2992,6 +3039,7 @@ def local_rebroadcast_lift(node):
     Rebroadcast(Elemwise(x)) => Elemwise(Rebroadcast(x))
     Rebroadcast(Rebroadcast(x)) => Rebroadcast(x)
+
     """
     op = node.op
     if not isinstance(op, T.Rebroadcast):
...
@@ -3023,8 +3071,14 @@ def apply_rebroadcast_opt(rval):
     Apply as many times as required the optimization local_useless_rebroadcast
     and local_rebroadcast_lift.

-    :param rval: a Variable
-    :return: a Variable (the same if no optimization can be applied)
+    Parameters
+    ----------
+    rval: a Variable
+
+    Returns
+    -------
+    A Variable (the same if no optimization can be applied)
     """
     changed = True
...
@@ -3056,6 +3110,7 @@ def local_join_1(node):
     """Join(i, x) => x

     Remove Join() when only one element is joined.
+
     """
     if not isinstance(node.op, T.Join):
         return
...
@@ -3070,7 +3125,8 @@ def local_join_1(node):
 def local_join_empty(node):
     """Join(i, x, y, empty) => Join(i, x, y)

-    remove empty inputs to joins. The empty inputs can be anywhere.
+    Remove empty inputs to joins. The empty inputs can be anywhere.
     """
     if not isinstance(node.op, T.Join):
         return
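Numerically, the Join(i, x, y, empty) => Join(i, x, y) rewrite just drops zero-length operands before concatenation, which NumPy shows directly:

```python
import numpy as np

parts = [np.array([1, 2]), np.array([], dtype=int), np.array([3])]
# empty inputs do not affect the concatenation result
joined = np.concatenate([p for p in parts if p.size > 0])
assert joined.tolist() == np.concatenate(parts).tolist() == [1, 2, 3]
```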
...
@@ -3147,6 +3203,7 @@ def local_remove_switch_const_cond(node):
     T.switch(cond,left,right) -->
     if cond is constant and cond == 0: right
     if cond is constant and cond != 0: left
+
     """
     if (isinstance(node.op, T.Elemwise) and
             isinstance(node.op.scalar_op, scalar.basic.Switch)):
...
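As context for the hunk above: the rewrite it documents (replacing `T.switch` with a constant condition by one of its branches) rests on a simple identity. A minimal sketch in plain Python, where the `switch` helper is a stand-in for Theano's element-wise switch, not the real Op:

```python
def switch(cond, left, right):
    # Scalar model of an element-wise switch: a nonzero condition
    # selects the left branch, zero selects the right branch.
    return left if cond != 0 else right

# With a constant condition the whole switch collapses to one branch,
# which is exactly what local_remove_switch_const_cond exploits.
assert switch(1, "left", "right") == "left"
assert switch(0, "left", "right") == "right"
```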
@@ -3183,7 +3240,9 @@ def local_mul_switch_sink(node):
     This is useful because A and B may not be numerically stable and give
     NaN or inf values for cases where the switch returns 0.
     With this optimization T.grad(T.switch(...)) has the right behavior.
-    Exemple:
+
+    Examples
+    --------
       x -> f(x)
       x -> g(x)
       y = T.switch(cond,f(x),g(x))
...
@@ -3193,6 +3252,7 @@ def local_mul_switch_sink(node):
     T.grad(y,x) -> switch(cond,grad(f(x),x), 0) + switch(cond,0,grad(g(x),x))
     This will be particularly useful for the lazyif because we skip
     an entire part of the graph.
+
     """
     if node.op != T.mul:
         return False
...
@@ -3234,6 +3294,7 @@ def local_div_switch_sink(node):
     This is useful because A may not be numerically stable and give
     NaN or inf values for cases where the switch returns 0.
     See local_mul_switch_sink for more details.
+
     """
     if (node.op != T.true_div and node.op != T.int_div):
         return False
...
@@ -3308,6 +3369,7 @@ def local_useless_split(node):
     """ Split{n_splits=1}(x, y) -> x
     Remove Split with only 1 split.
+
     """
     if isinstance(node.op, T.Split):
         if node.op.len_splits == 1:
...
@@ -3329,6 +3391,7 @@ def local_flatten_lift(node):
...
@@ -3329,6 +3391,7 @@ def local_flatten_lift(node):
This optimization is needed by optimization
This optimization is needed by optimization
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a flatten.
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a flatten.
"""
"""
if
(
isinstance
(
node
.
op
,
T
.
Flatten
)
and
if
(
isinstance
(
node
.
op
,
T
.
Flatten
)
and
node
.
inputs
[
0
]
.
owner
and
node
.
inputs
[
0
]
.
owner
and
...
@@ -3347,6 +3410,7 @@ def local_flatten_lift(node):
...
@@ -3347,6 +3410,7 @@ def local_flatten_lift(node):
def
local_reshape_chain
(
node
):
def
local_reshape_chain
(
node
):
"""
"""
Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
"""
"""
if
not
opt
.
check_chain
(
node
,
T
.
Reshape
,
T
.
Reshape
):
if
not
opt
.
check_chain
(
node
,
T
.
Reshape
,
T
.
Reshape
):
return
False
return
False
...
@@ -3378,6 +3442,7 @@ def local_reshape_lift(node):
     This optimization is needed by optimization
     nnet/sigm.py:log1msigm_to_softplus to get applied when there is a reshape.
+
     """
     if (isinstance(node.op, T.Reshape) and
             node.inputs[0].owner and
...
@@ -3526,26 +3591,32 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -3526,26 +3591,32 @@ class Canonizer(gof.LocalOptimizer):
Usage: Canonizer(main, inverse, reciprocal, calculate)
Usage: Canonizer(main, inverse, reciprocal, calculate)
* main: a suitable Op class that is commutative, associative and
Parameters
takes one to an arbitrary number of inputs, e.g. add or
----------
mul
main
* inverse: an Op class such that inverse(main(x, y), y) == x
A suitable Op class that is commutative, associative and
e.g. sub or true_div
takes one to an arbitrary number of inputs, e.g. add or
* reciprocal: a function such that main(x, reciprocal(y)) ==
mul
inverse(x, y) e.g. neg or inv
inverse
An Op class such that inverse(main(x, y), y) == x
* calculate: function that takes a list of numpy.ndarray instances
e.g. sub or true_div
for the numerator, another list for the denumerator,
reciprocal
and calculates inverse(main(*num), main(*denum)). It
A function such that main(x, reciprocal(y)) == inverse(x, y)
takes a keyword argument, aslist. If True, the value
e.g. neg or inv
should be returned as a list of one element, unless
calculate
the value is such that value = main(). In that case,
Function that takes a list of numpy.ndarray instances
the return value should be an empty list.
for the numerator, another list for the denumerator,
and calculates inverse(main(*num), main(*denum)). It
takes a keyword argument, aslist. If True, the value
should be returned as a list of one element, unless
the value is such that value = main(). In that case,
the return value should be an empty list.
The variable is a local_optimizer. It is best used with a TopoOptimizer in
The variable is a local_optimizer. It is best used with a TopoOptimizer in
in_to_out order.
in_to_out order.
Examples:
Examples
--------
T = theano.tensor
T = theano.tensor
add_canonizer = Canonizer(T.add, T.sub, T.neg,
add_canonizer = Canonizer(T.add, T.sub, T.neg,
lambda n, d: sum(n) - sum(d))
lambda n, d: sum(n) - sum(d))
...
@@ -3563,6 +3634,7 @@ class Canonizer(gof.LocalOptimizer):
     2 * x / 2 -> x
     x * y * z -> Elemwise(T.mul){x,y,z} #only one pass over the memory.
                !-> Elemwise(T.mul){x,Elemwise(T.mul){y,z}}
+
     """
     def __init__(self, main, inverse, reciprocal, calculate,
...
@@ -3747,8 +3819,11 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -3747,8 +3819,11 @@ class Canonizer(gof.LocalOptimizer):
@staticmethod
@staticmethod
def
get_constant
(
v
):
def
get_constant
(
v
):
"""
"""
Returns a numeric constant if v is a Constant or, well, a
Returns
numeric constant. If v is a plain Variable, returns None.
-------
A numeric constant if v is a Constant or, well, a
numeric constant. If v is a plain Variable, returns None.
"""
"""
if
isinstance
(
v
,
Variable
):
if
isinstance
(
v
,
Variable
):
try
:
try
:
...
@@ -3762,6 +3837,7 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -3762,6 +3837,7 @@ class Canonizer(gof.LocalOptimizer):
"""
"""
Shorthand for:
Shorthand for:
self.simplify_constants(*self.simplify_factors(num, denum))
self.simplify_constants(*self.simplify_factors(num, denum))
"""
"""
rval
=
self
.
simplify_constants
(
*
self
.
simplify_factors
(
num
,
denum
),
rval
=
self
.
simplify_constants
(
*
self
.
simplify_factors
(
num
,
denum
),
out_type
=
out_type
)
out_type
=
out_type
)
...
@@ -3781,6 +3857,7 @@ class Canonizer(gof.LocalOptimizer):
         [x], [x] -> [], []
         [x, y], [x] -> [y], []
         [a, b], [c, d] -> [a, b], [c, d]
+
         """
         for v in list(num):
             if v in denum:
...
@@ -3790,18 +3867,22 @@ class Canonizer(gof.LocalOptimizer):
     def simplify_constants(self, orig_num, orig_denum, out_type=None):
         """
+        Find all constants and put them together into a single constant.
+
         Finds all constants in orig_num and orig_denum (using
         get_constant) and puts them together into a single
         constant. The constant is inserted as the first element of the
         numerator. If the constant is the neutral element, it is
         removed from the numerator.
 
-        Examples:
+        Examples
+        --------
         Let main be multiplication:
             [2, 3, x], [] -> [6, x], []
             [x, y, 2], [4, z] -> [0.5, x, y], [z]
             [x, 2, y], [z, 2] -> [x, y], [z]
+
         """
         # Lists representing the numerator and denumerator
...
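The three docstring examples in the hunk above can be reproduced with a small stand-alone sketch of the constant-folding step. This is a simplified model with `main` fixed to multiplication, and is not the real method (which also tracks dtypes and the `out_type` argument); plain numbers play the role of Theano Constants and strings stand in for symbolic variables:

```python
def simplify_constants(num, denum):
    """Fold the numeric constants of num/denum into one leading constant."""
    const, new_num, new_denum = 1.0, [], []
    for v in num:
        if isinstance(v, (int, float)):
            const *= v          # constants in the numerator multiply in
        else:
            new_num.append(v)
    for v in denum:
        if isinstance(v, (int, float)):
            const /= v          # constants in the denumerator divide out
        else:
            new_denum.append(v)
    if const != 1.0:            # 1 is the neutral element of mul: drop it
        new_num.insert(0, const)
    return new_num, new_denum

# The three examples from the docstring, with main = multiplication:
assert simplify_constants([2, 3, 'x'], []) == ([6.0, 'x'], [])
assert simplify_constants(['x', 'y', 2], [4, 'z']) == ([0.5, 'x', 'y'], ['z'])
assert simplify_constants(['x', 2, 'y'], ['z', 2]) == (['x', 'y'], ['z'])
```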
@@ -3969,13 +4050,15 @@ register_canonicalize(local_neg_to_mul)
 @register_specialize
 @gof.local_optimizer([T.Sum, T.elemwise.Prod])
 def local_sum_prod_mul_by_scalar(node):
-    """sum(scalar * smth) -> scalar * sum(smth)
-       sum(-smth) -> -sum(smth)
+    """
+    sum(scalar * smth) -> scalar * sum(smth)
+    sum(-smth) -> -sum(smth)
 
     or
 
-       prod(scalar * smth) -> scalar ** size(smth) * prod(smth)
-       prod(-smth) -> -1 ** size(smth) * prod(smth)
+    prod(scalar * smth) -> scalar ** size(smth) * prod(smth)
+    prod(-smth) -> -1 ** size(smth) * prod(smth)
+
     """
     # TODO: if the the thing inside the Sum is a division,
     # we should get at the numerator....
...
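The algebraic identities behind the rewrite documented above are easy to spot-check numerically. A quick sanity check in plain Python, using floats in place of tensors:

```python
import math

xs = [1.5, -2.0, 4.25, 0.5]
c = 3.0

# sum(scalar * smth) -> scalar * sum(smth)
assert abs(sum(c * v for v in xs) - c * sum(xs)) < 1e-9

# prod(scalar * smth) -> scalar ** size(smth) * prod(smth)
lhs = math.prod(c * v for v in xs)
rhs = c ** len(xs) * math.prod(xs)
assert abs(lhs - rhs) < 1e-9
```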
@@ -4040,8 +4123,11 @@ def local_elemwise_sub_zeros(node):
...
@@ -4040,8 +4123,11 @@ def local_elemwise_sub_zeros(node):
@register_specialize
@register_specialize
@gof.local_optimizer
([
T
.
Sum
])
@gof.local_optimizer
([
T
.
Sum
])
def
local_sum_div_dimshuffle
(
node
):
def
local_sum_div_dimshuffle
(
node
):
'''sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
"""
if dimension l of the DimShuffle is 'x'.'''
sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
if dimension l of the DimShuffle is 'x'.
"""
# TODO: extend it to product, and quotient of products
# TODO: extend it to product, and quotient of products
# It does not make much sense now to extend it to the case where the
# It does not make much sense now to extend it to the case where the
...
@@ -4128,8 +4214,10 @@ def local_sum_div_dimshuffle(node):
...
@@ -4128,8 +4214,10 @@ def local_sum_div_dimshuffle(node):
@register_canonicalize
@register_canonicalize
@gof.local_optimizer
([
T
.
Sum
,
T
.
elemwise
.
Prod
])
@gof.local_optimizer
([
T
.
Sum
,
T
.
elemwise
.
Prod
])
def
local_sum_prod_all_to_none
(
node
):
def
local_sum_prod_all_to_none
(
node
):
"""Sum{0,1,...N} -> Sum{} or
"""
Prod{0,1,...N} -> Prod{}
Sum{0,1,...N} -> Sum{} or
Prod{0,1,...N} -> Prod{}
"""
"""
if
isinstance
(
node
.
op
,
T
.
Sum
)
or
isinstance
(
node
.
op
,
T
.
elemwise
.
Prod
):
if
isinstance
(
node
.
op
,
T
.
Sum
)
or
isinstance
(
node
.
op
,
T
.
elemwise
.
Prod
):
opt_type
=
T
.
Sum
if
isinstance
(
node
.
op
,
T
.
Sum
)
else
T
.
elemwise
.
Prod
opt_type
=
T
.
Sum
if
isinstance
(
node
.
op
,
T
.
Sum
)
else
T
.
elemwise
.
Prod
...
@@ -4148,6 +4236,7 @@ def local_op_of_op(node):
...
@@ -4148,6 +4236,7 @@ def local_op_of_op(node):
Prod(Prod()) -> single Prod()
Prod(Prod()) -> single Prod()
or
or
Sum(Sum()) -> single Sum()
Sum(Sum()) -> single Sum()
"""
"""
if
isinstance
(
node
.
op
,
T
.
elemwise
.
Prod
)
or
isinstance
(
node
.
op
,
T
.
Sum
):
if
isinstance
(
node
.
op
,
T
.
elemwise
.
Prod
)
or
isinstance
(
node
.
op
,
T
.
Sum
):
opt_type
=
T
.
Sum
if
isinstance
(
node
.
op
,
T
.
Sum
)
else
T
.
elemwise
.
Prod
opt_type
=
T
.
Sum
if
isinstance
(
node
.
op
,
T
.
Sum
)
else
T
.
elemwise
.
Prod
...
@@ -4219,14 +4308,16 @@ ALL_REDUCE = [T.elemwise.CAReduce, T.elemwise.All, T.elemwise.Any,
 @register_uncanonicalize  # Needed for MaxAndArgmax -> CAReduce
 @gof.local_optimizer(ALL_REDUCE)
 def local_reduce_join(node):
-    """Reduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
+    """
+    Reduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
 
-    :note: supported scalar.op are Maximum, Mimimum in some cases and
-           Add and Mul in all cases.
+    Notes
+    -----
+    Supported scalar.op are Maximum, Mimimum in some cases and Add and Mul in
+    all cases.
 
-    :note: Currently we must reduce on axis 0. It is probably
-           extensible to the case where we join and reduce on the same
-           set of axis.
+    Currently we must reduce on axis 0. It is probably extensible to the case
+    where we join and reduce on the same set of axis.
     """
     if (isinstance(node.op, T.CAReduce) and
...
@@ -4312,7 +4403,7 @@ def local_cut_useless_reduce(node):
 @register_specialize
 @gof.local_optimizer(ALL_REDUCE)
 def local_reduce_broadcastable(node):
-    """Remove reduction over broadcastable dimensions"""
+    """Remove reduction over broadcastable dimensions."""
     if isinstance(node.op, T.CAReduce):
         reduced, = node.inputs
         odtype = node.outputs[0].dtype
...
@@ -4351,9 +4442,11 @@ def local_reduce_broadcastable(node):
 @register_specialize
 @gof.local_optimizer([T.Sum, T.elemwise.Prod])
 def local_opt_alloc(node):
-    """ sum(alloc(constant,shapes...)) => constant*prod(shapes)
-        or
-        prod(alloc(constant,shapes...)) => constant**prod(shapes)
+    """
+    sum(alloc(constant,shapes...)) => constant*prod(shapes)
+    or
+    prod(alloc(constant,shapes...)) => constant**prod(shapes)
+
     """
     if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod):
         node_inps, = node.inputs
...
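The identities in the docstring above can likewise be checked directly in plain Python. Here a list of repeated values stands in for `alloc(constant, *shape)`; the small shape keeps every intermediate product exactly representable as a float:

```python
import math

shape = (2, 3)
constant = 2.5
n = math.prod(shape)            # number of elements in alloc(constant, *shape)
filled = [constant] * n         # model of the constant-filled tensor

# sum(alloc(constant, shapes...)) => constant * prod(shapes)
assert sum(filled) == constant * n

# prod(alloc(constant, shapes...)) => constant ** prod(shapes)
assert math.prod(filled) == constant ** n
```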
@@ -4406,9 +4499,11 @@ def local_neg_neg(node):
...
@@ -4406,9 +4499,11 @@ def local_neg_neg(node):
@register_specialize
@register_specialize
@gof.local_optimizer
([
T
.
neg
])
@gof.local_optimizer
([
T
.
neg
])
def
local_neg_div_neg
(
node
):
def
local_neg_div_neg
(
node
):
"""- (-a / b) -> a / b
"""
- (-a / b) -> a / b
Also performs - (c / b) -> ((-c) / b) when c is a scalar constant.
Also performs - (c / b) -> ((-c) / b) when c is a scalar constant.
"""
"""
if
node
.
op
==
T
.
neg
:
if
node
.
op
==
T
.
neg
:
if
node
.
inputs
[
0
]
.
owner
and
node
.
inputs
[
0
]
.
owner
.
op
==
T
.
true_div
:
if
node
.
inputs
[
0
]
.
owner
and
node
.
inputs
[
0
]
.
owner
.
op
==
T
.
true_div
:
...
@@ -4427,8 +4522,10 @@ def local_neg_div_neg(node):
 @gof.local_optimizer([T.mul])
 def local_mul_zero(node):
-    """As part of canonicalization, we replace multiplication by zero
+    """
+    As part of canonicalization, we replace multiplication by zero
     with zero.
+
     """
     if node.op == T.mul:
         otype = node.outputs[0].type
...
@@ -4489,10 +4586,12 @@ register_canonicalize(local_pow_canonicalize)
 @register_specialize
 @gof.local_optimizer([T.mul])
 def local_mul_to_sqr(node):
-    """x*x -> sqr(x)
+    """
+    x*x -> sqr(x)
 
     This is faster on the GPU when memory fetching is a big part of
     the computation time.
+
     """
     if node.op == T.mul:
         if len(node.inputs) == 2:
...
@@ -4620,7 +4719,8 @@ def local_pow_specialize_device(node):
 @gof.local_optimizer([T.mul])
 def local_mul_specialize(node):
-    """Remove special-case constants from mul arguments and useless neg in inputs.
+    """
+    Remove special-case constants from mul arguments and useless neg in inputs.
 
     mul(-1, x) -> neg(x)
     mul(1, x, y) -> mul(x, y)
...
@@ -4629,6 +4729,7 @@ def local_mul_specialize(node):
     This is not done if we would add more nodes in the graph, like with:
 
     mul(-1, x, y) -/-> neg(mul(x, y))
+
     """
     # here, we are past the point of canonicalization, so we don't
     # want to put in un-necessary fills.
...
@@ -4766,8 +4867,9 @@ local_mul_canonizer.add_simplifier(check_for_x_over_absX, 'X_over_absX')
 @gof.local_optimizer([T.abs_])
 def local_abs_lift(node):
     """
-    move the abs toward the input. This is needed for
-    check_for_x_over_absX to apply in more case.
+    Move the abs toward the input.
+
+    This is needed for check_for_x_over_absX to apply in more case.
     """
     if node.op == T.abs_ and node.inputs[0].owner:
...
@@ -4783,7 +4885,7 @@ def local_abs_lift(node):
 @gof.local_optimizer([T.mul, T.true_div])
 def local_abs_merge(node):
     """
-    merge abs generated by local_abs_lift when the canonizer don't
+    Merge abs generated by local_abs_lift when the canonizer don't
     need it anymore
     """
...
@@ -4968,6 +5070,8 @@ def attempt_distribution(factor, num, denum, out_type):
 @gof.local_optimizer([T.mul, T.true_div, T.inv])
 def local_greedy_distributor(node):
     """
+    Optimize by reducing the number of multiplications and/or divisions.
+
     This optimization tries to apply distributivity of multiplication
     to addition in order to reduce the number of multiplications
     and/or divisions that must be done. The algorithm weighs division
...
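The distributivity this optimizer relies on can be spot-checked numerically. Plain floats stand in for tensors here; this is only the underlying algebra, not the graph rewrite itself:

```python
a, b, x = 3.0, 5.0, 4.0

# Factoring multiplication out of a sum: x*a + x*b -> x*(a + b)
# turns two multiplications into one.
assert x * a + x * b == x * (a + b)

# Factoring out a common divisor: a/x + b/x -> (a + b)/x
# turns two divisions into one.
assert a / x + b / x == (a + b) / x
```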
@@ -4985,6 +5089,7 @@ def local_greedy_distributor(node):
     This optimization aims to reduce computational cost. It may also
     increase numerical stability, e.g. when x and/or y tend to 0 in
     example 1.
+
     """
     out = node.outputs[0]
...
@@ -5083,7 +5188,12 @@ def constant_folding(node):
 def _is_1(expr):
-    """rtype bool. True iff expr is a constant close to 1
+    """
+    Returns
+    -------
+    bool
+        True iff expr is a constant close to 1.
+
     """
     try:
         v = get_scalar_constant_value(expr)
...
@@ -5093,7 +5203,12 @@ def _is_1(expr):
 def _is_minus1(expr):
-    """rtype bool. True iff expr is a constant close to -1
+    """
+    Returns
+    -------
+    bool
+        True iff expr is a constant close to -1.
+
     """
     try:
         v = get_scalar_constant_value(expr)
...
@@ -5103,13 +5218,13 @@ def _is_minus1(expr):
 def get_clients(node):
-    "Used by erf/erfc opt to track less frequent op"
+    """Used by erf/erfc opt to track less frequent op."""
     return [c for c, i in node.outputs[0].clients
             if c != "output"]
 
 
 def get_clients2(node):
-    "Used by erf/erfc opt to track less frequent op"
+    """Used by erf/erfc opt to track less frequent op."""
     l = []
     for c, i in node.outputs[0].clients:
         if c != "output":
...
@@ -5622,18 +5737,22 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
     """
     We parametrize it to make it work for Elemwise and GpuElemwise op.
 
-    :param OP: GpuElemwise or Elemwise class (the one that we want to fuse)
-
-    :param max_input_fct: a function that returns the maximum number of inputs
-                          that this elemwise can take (useful for GpuElemwise).
-                          GPU kernel currently has a limit of 256 bytes for
-                          the size of all parameters passed to it. As currently
-                          we pass many information only by parameter, we must
-                          limit how many ops we fuse together to avoid busting
-                          that 256 limit.
+    Parameters
+    ----------
+    OP
+        GpuElemwise or Elemwise class (the one that we want to fuse)
+    max_input_fct
+        A function that returns the maximum number of inputs
+        that this elemwise can take (useful for GpuElemwise).
+        GPU kernel currently has a limit of 256 bytes for
+        the size of all parameters passed to it. As currently
+        we pass many information only by parameter, we must
+        limit how many ops we fuse together to avoid busting
+        that 256 limit.
 
-                          On the CPU we limit to 32 input variables
-                          since that is the maximum numpy support.
+        On the CPU we limit to 32 input variables
+        since that is the maximum numpy support.
+
     """
     if maker is None:
         def maker(node, scalar_op):
...
@@ -5647,6 +5766,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
     For mixed dtype, we let the Composite op do the cast. It lets the C
     compiler do the cast.
     The number of dimensions is validated at call time by theano itself.
+
     """
     # META TODO: PUT THESE THINGS IN TRAC, NOT TODO NOTES!!
     # TODO: use broadcast flag?
...
@@ -5862,7 +5982,7 @@ local_elemwise_fusion = local_elemwise_fusion_op(T.Elemwise,
 class FusionOptimizer(Optimizer):
-    """Graph optimizer for Fusion of elemwise operations"""
+    """Graph optimizer for Fusion of elemwise operations."""
     def __init__(self, local_optimizer):
         Optimizer.__init__(self)
         self.optimizer = local_optimizer
...
...