提交 f6141a60 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

numpydoc for theano/tensor/opt.py

上级 7f312182
""" """
Tensor optimizations addressing the ops in basic.py Tensor optimizations addressing the ops in basic.py.
""" """
from __future__ import print_function from __future__ import print_function
# TODO: intelligent merge for mul/add # TODO: intelligent merge for mul/add
...@@ -68,15 +68,20 @@ def copy_stack_trace(from_var, to_var): ...@@ -68,15 +68,20 @@ def copy_stack_trace(from_var, to_var):
Copies the stack trace from one or more tensor variables to Copies the stack trace from one or more tensor variables to
one or more tensor variables. one or more tensor variables.
:param from_var: tensor variable or list of tensor variables to Parameters
copy stack traces from. ----------
:param to_var: tensor variable or list of tensor variables to from_var
copy stack traces to. Tensor variable or list of tensor variables to copy stack traces from.
to_var
Tensor variable or list of tensor variables to copy stack traces to.
.. note:: The stacktrace is assumed to be of the form of a list of lists Notes
-----
The stacktrace is assumed to be of the form of a list of lists
of tuples. Each tuple contains the filename, line number, function name of tuples. Each tuple contains the filename, line number, function name
and so on. Each list of tuples contains the tuples belonging to a and so on. Each list of tuples contains the tuples belonging to a
particular variable. particular variable.
""" """
# Store stack traces from from_var # Store stack traces from from_var
...@@ -151,11 +156,18 @@ def _fill_chain(new_out, orig_inputs): ...@@ -151,11 +156,18 @@ def _fill_chain(new_out, orig_inputs):
def encompasses_broadcastable(b1, b2): def encompasses_broadcastable(b1, b2):
""" """
Returns True if the broadcastable patterns b1 and b2 are such that b2 is Parameters
----------
b1
The broadcastable attribute of a tensor type.
b2
The broadcastable attribute of a tensor type.
Returns
-------
True if the broadcastable patterns b1 and b2 are such that b2 is
broadcasted to b1's shape and not the opposite. broadcasted to b1's shape and not the opposite.
:param b1: the broadcastable attribute of a tensor type
:param b2: the broadcastable attribute of a tensor type
""" """
if len(b1) < len(b2): if len(b1) < len(b2):
return False return False
...@@ -184,7 +196,8 @@ def scalarconsts_rest(inputs): ...@@ -184,7 +196,8 @@ def scalarconsts_rest(inputs):
def broadcast_like(value, template, fgraph, dtype=None): def broadcast_like(value, template, fgraph, dtype=None):
"""Return a Variable with the same shape and dtype as the template, """
Return a Variable with the same shape and dtype as the template,
filled by broadcasting value through it. `value` will be cast as filled by broadcasting value through it. `value` will be cast as
necessary. necessary.
...@@ -240,9 +253,11 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -240,9 +253,11 @@ def inplace_elemwise_optimizer_op(OP):
see if it can operate inplace on that input. If so, makes the see if it can operate inplace on that input. If so, makes the
change and go to the next output or Broadcast Op. change and go to the next output or Broadcast Op.
Examples: Examples
--------
x + y + z -> x += y += z x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y) (x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
""" """
# We should not validate too often as this takes too much time to # We should not validate too often as this takes too much time to
# execute! # execute!
...@@ -507,6 +522,7 @@ def local_dimshuffle_lift(node): ...@@ -507,6 +522,7 @@ def local_dimshuffle_lift(node):
After this transform, clusters of Elemwise operations are After this transform, clusters of Elemwise operations are
void of DimShuffle operations. void of DimShuffle operations.
""" """
op = node.op op = node.op
if not isinstance(op, DimShuffle): if not isinstance(op, DimShuffle):
...@@ -556,6 +572,7 @@ def local_lift_transpose_through_dot(node): ...@@ -556,6 +572,7 @@ def local_lift_transpose_through_dot(node):
The transformation should be applied whether or not the transpose is The transformation should be applied whether or not the transpose is
inplace. The newly-introduced transpositions are not inplace, this will inplace. The newly-introduced transpositions are not inplace, this will
be taken care of in a later optimization phase. be taken care of in a later optimization phase.
""" """
if not (isinstance(node.op, T.DimShuffle) and node.op.new_order == (1, 0)): if not (isinstance(node.op, T.DimShuffle) and node.op.new_order == (1, 0)):
return False return False
...@@ -639,11 +656,12 @@ def local_scalar_tensor_scalar(node): ...@@ -639,11 +656,12 @@ def local_scalar_tensor_scalar(node):
class MakeVector(T.Op): class MakeVector(T.Op):
"""Concatenate a number of scalars together into a vector """Concatenate a number of scalars together into a vector.
This is a simple version of stack() that introduces far less cruft This is a simple version of stack() that introduces far less cruft
into the graph. Should work with 0 inputs. The constant_folding into the graph. Should work with 0 inputs. The constant_folding
optimization will remove it. optimization will remove it.
""" """
__props__ = ("dtype",) __props__ = ("dtype",)
...@@ -755,7 +773,7 @@ T.pprint.assign(lambda pstate, r: r.owner and ...@@ -755,7 +773,7 @@ T.pprint.assign(lambda pstate, r: r.owner and
class ShapeFeature(object): class ShapeFeature(object):
"""Graph optimizer for removing all calls to shape() """Graph optimizer for removing all calls to shape().
This optimizer replaces all Shapes and Subtensors of Shapes with This optimizer replaces all Shapes and Subtensors of Shapes with
Shape_i and MakeVector Ops. Shape_i and MakeVector Ops.
...@@ -791,7 +809,6 @@ class ShapeFeature(object): ...@@ -791,7 +809,6 @@ class ShapeFeature(object):
For example the infer_shape for a matrix-matrix product would accept For example the infer_shape for a matrix-matrix product would accept
input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),). input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),).
Inferring the shape of internal nodes in the graph is important Inferring the shape of internal nodes in the graph is important
for doing size-driven optimizations. If we know how big various for doing size-driven optimizations. If we know how big various
intermediate results will be, we can estimate the cost of many Ops intermediate results will be, we can estimate the cost of many Ops
...@@ -800,8 +817,8 @@ class ShapeFeature(object): ...@@ -800,8 +817,8 @@ class ShapeFeature(object):
In cases where you cannot figure out the shape, raise a ShapeError. In cases where you cannot figure out the shape, raise a ShapeError.
.. note:: Notes
-----
Right now there is only the ConvOp that could really take Right now there is only the ConvOp that could really take
advantage of this shape inference, but it is worth it even advantage of this shape inference, but it is worth it even
just for the ConvOp. All that's necessary to do shape just for the ConvOp. All that's necessary to do shape
...@@ -842,7 +859,7 @@ class ShapeFeature(object): ...@@ -842,7 +859,7 @@ class ShapeFeature(object):
""" """
def shape_ir(self, i, r): def shape_ir(self, i, r):
"""Return symbolic r.shape[i] for tensor variable r, int i""" """Return symbolic r.shape[i] for tensor variable r, int i."""
if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]: if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]:
return self.lscalar_one return self.lscalar_one
else: else:
...@@ -855,7 +872,7 @@ class ShapeFeature(object): ...@@ -855,7 +872,7 @@ class ShapeFeature(object):
return s return s
def shape_tuple(self, r): def shape_tuple(self, r):
"""Return a tuple of symbolic shape vars for tensor variable r""" """Return a tuple of symbolic shape vars for tensor variable r."""
if not hasattr(r, 'ndim'): if not hasattr(r, 'ndim'):
# This happens for NoneConst. # This happens for NoneConst.
return None return None
...@@ -867,6 +884,7 @@ class ShapeFeature(object): ...@@ -867,6 +884,7 @@ class ShapeFeature(object):
This function is used for Ops that don't implement infer_shape. This function is used for Ops that don't implement infer_shape.
Ops that do implement infer_shape should use the i_shapes parameter, Ops that do implement infer_shape should use the i_shapes parameter,
but this default implementation ignores it. but this default implementation ignores it.
""" """
rval = [] rval = []
for r in node.outputs: for r in node.outputs:
...@@ -880,6 +898,7 @@ class ShapeFeature(object): ...@@ -880,6 +898,7 @@ class ShapeFeature(object):
"""Return a symbolic integer scalar for the shape element s_i. """Return a symbolic integer scalar for the shape element s_i.
The s_i argument was produced by the infer_shape() of an Op subclass. The s_i argument was produced by the infer_shape() of an Op subclass.
""" """
# unpack the s_i that the Op returned # unpack the s_i that the Op returned
assert s_i is not None assert s_i is not None
...@@ -933,8 +952,11 @@ class ShapeFeature(object): ...@@ -933,8 +952,11 @@ class ShapeFeature(object):
def set_shape(self, r, s): def set_shape(self, r, s):
"""Assign the shape `s` to previously un-shaped variable `r`. """Assign the shape `s` to previously un-shaped variable `r`.
:type r: a variable Parameters
:type s: None or a tuple of symbolic integers ----------
r : a variable
s : None or a tuple of symbolic integers
""" """
assert r not in self.shape_of, 'r already in shape_of' assert r not in self.shape_of, 'r already in shape_of'
if s is None: if s is None:
...@@ -972,11 +994,12 @@ class ShapeFeature(object): ...@@ -972,11 +994,12 @@ class ShapeFeature(object):
self.shape_of_reverse_index.setdefault(sv, set()).add(r) self.shape_of_reverse_index.setdefault(sv, set()).add(r)
def update_shape(self, r, other_r): def update_shape(self, r, other_r):
'''Replace shape of r by shape of other_r. """Replace shape of r by shape of other_r.
If, on some dimensions, the shape of other_r is not informative, If, on some dimensions, the shape of other_r is not informative,
keep the shape of r on those dimensions. keep the shape of r on those dimensions.
'''
"""
# other_r should already have a shape # other_r should already have a shape
assert other_r in self.shape_of, ('other_r not in shape_of', other_r) assert other_r in self.shape_of, ('other_r not in shape_of', other_r)
other_shape = self.shape_of[other_r] other_shape = self.shape_of[other_r]
...@@ -1303,8 +1326,7 @@ class ShapeFeature(object): ...@@ -1303,8 +1326,7 @@ class ShapeFeature(object):
class ShapeOptimizer(Optimizer): class ShapeOptimizer(Optimizer):
"""Optimizer that serves to add ShapeFeature as an fgraph feature. """Optimizer that serves to add ShapeFeature as an fgraph feature."""
"""
def __init__(self): def __init__(self):
Optimizer.__init__(self) Optimizer.__init__(self)
...@@ -1392,6 +1414,7 @@ def local_useless_alloc(node): ...@@ -1392,6 +1414,7 @@ def local_useless_alloc(node):
If the input type is the same as the output type (dtype and broadcast) If the input type is the same as the output type (dtype and broadcast)
there is no change in the shape of the input. So this is just a simple copy there is no change in the shape of the input. So this is just a simple copy
of the input. This is not needed. of the input. This is not needed.
""" """
if node.op == T.alloc: if node.op == T.alloc:
if node.inputs[0].type == node.outputs[0].type: if node.inputs[0].type == node.outputs[0].type:
...@@ -1438,14 +1461,15 @@ def local_track_shape_i(node): ...@@ -1438,14 +1461,15 @@ def local_track_shape_i(node):
@gof.local_optimizer([Subtensor, AdvancedSubtensor1]) @gof.local_optimizer([Subtensor, AdvancedSubtensor1])
def local_subtensor_make_vector(node): def local_subtensor_make_vector(node):
""" """
replace all subtensor(make_vector) like: Replace all subtensor(make_vector) like:
[a,b,c][0] -> a [a,b,c][0] -> a
[a,b,c][0:2] -> [a,b] [a,b,c][0:2] -> [a,b]
replace all AdvancedSubtensor1(make_vector) like: Replace all AdvancedSubtensor1(make_vector) like:
[a,b,c][[0,2]] -> [a,c] [a,b,c][[0,2]] -> [a,c]
we can do this for constant indexes We can do this for constant indexes.
""" """
x = node.inputs[0] x = node.inputs[0]
if not x.owner or x.owner.op != make_vector: if not x.owner or x.owner.op != make_vector:
...@@ -1514,7 +1538,6 @@ def local_subtensor_make_vector(node): ...@@ -1514,7 +1538,6 @@ def local_subtensor_make_vector(node):
@gof.local_optimizer([T.Elemwise]) @gof.local_optimizer([T.Elemwise])
def local_useless_elemwise(node): def local_useless_elemwise(node):
""" """
eq(x,x) -> 1 eq(x,x) -> 1
neq(x,x) -> 0 neq(x,x) -> 0
mul(x) -> x mul(x) -> x
...@@ -1559,8 +1582,7 @@ def local_useless_elemwise(node): ...@@ -1559,8 +1582,7 @@ def local_useless_elemwise(node):
@register_specialize @register_specialize
@gof.local_optimizer([T.Elemwise]) @gof.local_optimizer([T.Elemwise])
def local_alloc_unary(node): def local_alloc_unary(node):
"""unary(alloc(x, shp)) -> alloc(unary(x), shp) """unary(alloc(x, shp)) -> alloc(unary(x), shp)"""
"""
if isinstance(node.op, T.Elemwise) and len(node.inputs) == 1: if isinstance(node.op, T.Elemwise) and len(node.inputs) == 1:
a = node.inputs[0] a = node.inputs[0]
if a.owner and isinstance(a.owner.op, T.Alloc): if a.owner and isinstance(a.owner.op, T.Alloc):
...@@ -1587,6 +1609,7 @@ def local_cast_cast(node): ...@@ -1587,6 +1609,7 @@ def local_cast_cast(node):
dtype1 == dtype2 dtype1 == dtype2
TODO: the base dtype is the same (int, uint, float, complex) TODO: the base dtype is the same (int, uint, float, complex)
and the first cast causes an upcast. and the first cast causes an upcast.
""" """
if (not isinstance(node.op, T.Elemwise) or if (not isinstance(node.op, T.Elemwise) or
not isinstance(node.op.scalar_op, scalar.Cast)): not isinstance(node.op.scalar_op, scalar.Cast)):
...@@ -1607,9 +1630,9 @@ def local_cast_cast(node): ...@@ -1607,9 +1630,9 @@ def local_cast_cast(node):
def local_func_inv(node): def local_func_inv(node):
""" """
Check for two consecutive operations that are functional inverses Check for two consecutive operations that are functional inverses
and remove them from the function graph and remove them from the function graph.
"""
"""
inv_pairs = ( inv_pairs = (
(basic.Deg2Rad, basic.Rad2Deg), (basic.Deg2Rad, basic.Rad2Deg),
(basic.Cosh, basic.ArcCosh), (basic.Cosh, basic.ArcCosh),
...@@ -1641,9 +1664,9 @@ def local_func_inv(node): ...@@ -1641,9 +1664,9 @@ def local_func_inv(node):
def is_inverse_pair(node_op, prev_op, inv_pair): def is_inverse_pair(node_op, prev_op, inv_pair):
""" """
Given two consecutive operations, check if they are the Given two consecutive operations, check if they are the
provided pair of inverse functions provided pair of inverse functions.
"""
"""
node_is_op0 = isinstance(node_op, inv_pair[0]) node_is_op0 = isinstance(node_op, inv_pair[0])
node_is_op1 = isinstance(node_op, inv_pair[1]) node_is_op1 = isinstance(node_op, inv_pair[1])
prev_is_op0 = isinstance(prev_op, inv_pair[0]) prev_is_op0 = isinstance(prev_op, inv_pair[0])
...@@ -1659,20 +1682,24 @@ class Assert(T.Op): ...@@ -1659,20 +1682,24 @@ class Assert(T.Op):
Returns the first parameter if the condition is true, otherwise, triggers Returns the first parameter if the condition is true, otherwise, triggers
AssertionError. AssertionError.
Example: Notes
T = theano.tensor -----
x = T.vector('x')
assert_op = T.opt.Assert()
func = theano.function([x], assert_op(x, x.size<2))
Notes:
This Op is a debugging feature. It can be removed from the graph This Op is a debugging feature. It can be removed from the graph
because of optimizations, and can hide some possible optimizations to because of optimizations, and can hide some possible optimizations to
the optimizer. Specifically, removing happens if it can be determined the optimizer. Specifically, removing happens if it can be determined
that condition will always be true. Also, the output of the Op must be that condition will always be true. Also, the output of the Op must be
used in the function computing the graph, but it doesn't have to be used in the function computing the graph, but it doesn't have to be
returned. returned.
Examples
--------
T = theano.tensor
x = T.vector('x')
assert_op = T.opt.Assert()
func = theano.function([x], assert_op(x, x.size<2))
""" """
__props__ = ('msg',) __props__ = ('msg',)
view_map = {0: [0]} view_map = {0: [0]}
...@@ -1770,7 +1797,9 @@ def local_remove_all_assert(node): ...@@ -1770,7 +1797,9 @@ def local_remove_all_assert(node):
"""An optimization disabled by default that removes all asserts from """An optimization disabled by default that removes all asserts from
the graph. the graph.
:note: See the :ref:`unsafe` section to know how to enable it. Notes
-----
See the :ref:`unsafe` section to know how to enable it.
""" """
if not isinstance(node.op, Assert): if not isinstance(node.op, Assert):
...@@ -1804,11 +1833,12 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1804,11 +1833,12 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
BROADCAST CONDITION: the condition is that the one input that are BROADCAST CONDITION: the condition is that the one input that are
not to be optimized to have the same broadcast pattern as the not to be optimized to have the same broadcast pattern as the
output output.
We can change the alloc by a dimshuffle as the elemwise We can change the alloc by a dimshuffle as the elemwise
already have the shape info. The dimshuffle will be faster already have the shape info. The dimshuffle will be faster
to exec to exec.
""" """
if not isinstance(node.op, ElemwiseOP): if not isinstance(node.op, ElemwiseOP):
return False return False
...@@ -1969,6 +1999,7 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -1969,6 +1999,7 @@ def local_upcast_elemwise_constant_inputs(node):
those Ops do implicit upcasting anyway. those Ops do implicit upcasting anyway.
Rationale: it helps merge things like (1-x) and (1.0 - x). Rationale: it helps merge things like (1-x) and (1.0 - x).
""" """
if len(node.outputs) > 1: if len(node.outputs) > 1:
return return
...@@ -2033,7 +2064,8 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -2033,7 +2064,8 @@ def local_upcast_elemwise_constant_inputs(node):
@register_specialize @register_specialize
@gof.local_optimizer([IncSubtensor]) @gof.local_optimizer([IncSubtensor])
def local_useless_inc_subtensor(node): def local_useless_inc_subtensor(node):
"""Remove IncSubtensor, when we overwrite the full inputs with the """
Remove IncSubtensor, when we overwrite the full inputs with the
new value. new value.
""" """
...@@ -2082,6 +2114,7 @@ def local_set_to_inc_subtensor(node): ...@@ -2082,6 +2114,7 @@ def local_set_to_inc_subtensor(node):
""" """
AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) -> AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) ->
AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False) AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False)
""" """
if (isinstance(node.op, AdvancedIncSubtensor1) and if (isinstance(node.op, AdvancedIncSubtensor1) and
node.op.set_instead_of_inc and node.op.set_instead_of_inc and
...@@ -2144,6 +2177,7 @@ def local_useless_subtensor(node): ...@@ -2144,6 +2177,7 @@ def local_useless_subtensor(node):
AdvancedSubtensor1 case, the full input is taken when the indices are AdvancedSubtensor1 case, the full input is taken when the indices are
equivalent to `arange(0, input.shape[0], 1)` using either an explicit equivalent to `arange(0, input.shape[0], 1)` using either an explicit
list/vector or the ARange op. list/vector or the ARange op.
""" """
# This optimization needs ShapeOpt and fgraph.shape_feature # This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(node.fgraph, 'shape_feature'): if not hasattr(node.fgraph, 'shape_feature'):
...@@ -2261,6 +2295,7 @@ def local_subtensor_lift(node): ...@@ -2261,6 +2295,7 @@ def local_subtensor_lift(node):
elemwise(x,...)[idx] -> elemwise(x[idx],...) elemwise(x,...)[idx] -> elemwise(x[idx],...)
when x,... are broadcasted scalar or not broadcasted at all when x,... are broadcasted scalar or not broadcasted at all
rebroadcast(x)[idx] => rebroadcast(x[idx]) rebroadcast(x)[idx] => rebroadcast(x[idx])
""" """
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
u = node.inputs[0] u = node.inputs[0]
...@@ -2327,7 +2362,7 @@ def local_subtensor_lift(node): ...@@ -2327,7 +2362,7 @@ def local_subtensor_lift(node):
def merge_two_slices(slice1, len1, slice2, len2): def merge_two_slices(slice1, len1, slice2, len2):
''' """
This function merges two slices into a single slice. The code works on This function merges two slices into a single slice. The code works on
the assumption that: the assumption that:
a) slice1 is actually a slice and not an index, while slice2 a) slice1 is actually a slice and not an index, while slice2
...@@ -2340,7 +2375,7 @@ def merge_two_slices(slice1, len1, slice2, len2): ...@@ -2340,7 +2375,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
the two consecutive slices. the two consecutive slices.
``len1`` is the length of the tensor **before** applying the first slice, ``len1`` is the length of the tensor **before** applying the first slice,
while ``len2`` is the length **after** applying the first slice. while ``len2`` is the length **after** applying the first slice.
''' """
list_opt = [local_abs_merge, local_mul_switch_sink, list_opt = [local_abs_merge, local_mul_switch_sink,
local_upcast_elemwise_constant_inputs, local_upcast_elemwise_constant_inputs,
local_remove_switch_const_cond, constant_folding] local_remove_switch_const_cond, constant_folding]
...@@ -2466,6 +2501,7 @@ def local_subtensor_merge(node): ...@@ -2466,6 +2501,7 @@ def local_subtensor_merge(node):
Refactored optimization to deal with all cases of tensor merging. Refactored optimization to deal with all cases of tensor merging.
Given a subgraph of the form Subtensor(Subtensor(u)), the optimization Given a subgraph of the form Subtensor(Subtensor(u)), the optimization
expresses all slices in a canonical form, and then merges them together. expresses all slices in a canonical form, and then merges them together.
""" """
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
...@@ -2601,7 +2637,8 @@ def local_subtensor_of_dot(node): ...@@ -2601,7 +2637,8 @@ def local_subtensor_of_dot(node):
idxs_a is the first A.ndim-1 entries of idxs, idxs_a is the first A.ndim-1 entries of idxs,
and idxs_b is the remaining entries of idxs (if any), and idxs_b is the remaining entries of idxs (if any),
modified to skip the second-to-last dimension of B modified to skip the second-to-last dimension of B
(because dot sums over this dimension) (because dot sums over this dimension).
""" """
if not isinstance(node.op, Subtensor): if not isinstance(node.op, Subtensor):
return return
...@@ -2715,7 +2752,8 @@ compile.optdb.register('pre_local_IncSubtensor_serialize', ...@@ -2715,7 +2752,8 @@ compile.optdb.register('pre_local_IncSubtensor_serialize',
@gof.local_optimizer([IncSubtensor], inplace=True) @gof.local_optimizer([IncSubtensor], inplace=True)
def local_inplace_setsubtensor(node): def local_inplace_setsubtensor(node):
""" """
Also work for GpuIncSubtensor Also work for GpuIncSubtensor.
""" """
if isinstance(node.op, IncSubtensor) and not node.op.inplace: if isinstance(node.op, IncSubtensor) and not node.op.inplace:
new_op = node.op.__class__( new_op = node.op.__class__(
...@@ -2734,7 +2772,10 @@ compile.optdb.register('local_inplace_setsubtensor', ...@@ -2734,7 +2772,10 @@ compile.optdb.register('local_inplace_setsubtensor',
@gof.local_optimizer([AdvancedIncSubtensor1], inplace=True) @gof.local_optimizer([AdvancedIncSubtensor1], inplace=True)
def local_inplace_incsubtensor1(node): def local_inplace_incsubtensor1(node):
""" also work for GpuAdvancedIncSubtensor1 """ """
Also work for GpuAdvancedIncSubtensor1.
"""
if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace: if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
new_op = node.op.clone_inplace() new_op = node.op.clone_inplace()
new_node = new_op(*node.inputs) new_node = new_op(*node.inputs)
...@@ -2756,6 +2797,7 @@ compile.optdb.register('local_inplace_incsubtensor1', ...@@ -2756,6 +2797,7 @@ compile.optdb.register('local_inplace_incsubtensor1',
def local_incsubtensor_of_zeros(node): def local_incsubtensor_of_zeros(node):
""" """
IncSubtensor(x, zeros, idx) -> x IncSubtensor(x, zeros, idx) -> x
""" """
if (isinstance(node.op, (IncSubtensor, if (isinstance(node.op, (IncSubtensor,
AdvancedIncSubtensor, AdvancedIncSubtensor,
...@@ -2784,6 +2826,7 @@ def local_setsubtensor_of_constants(node): ...@@ -2784,6 +2826,7 @@ def local_setsubtensor_of_constants(node):
SetSubtensor(x, x[idx], idx) -> x SetSubtensor(x, x[idx], idx) -> x
when x is constant or alloc. when x is constant or alloc.
""" """
if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc: if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc:
x = node.inputs[0] x = node.inputs[0]
...@@ -2813,12 +2856,14 @@ def local_setsubtensor_of_constants(node): ...@@ -2813,12 +2856,14 @@ def local_setsubtensor_of_constants(node):
@register_stabilize @register_stabilize
@gof.local_optimizer([AdvancedSubtensor1]) @gof.local_optimizer([AdvancedSubtensor1])
def local_adv_sub1_adv_inc_sub1(node): def local_adv_sub1_adv_inc_sub1(node):
"""Optimize the possible AdvSub1(AdvIncSub1(...), ...) """Optimize the possible AdvSub1(AdvIncSub1(...), ...).
AdvancedSubtensor1(AdvancedIncSubtensor1(0s, y, idx), idx) -> y AdvancedSubtensor1(AdvancedIncSubtensor1(0s, y, idx), idx) -> y
AdvancedSubtensor1(AdvancedSetSubtensor1(x, y, idx), idx) -> y AdvancedSubtensor1(AdvancedSetSubtensor1(x, y, idx), idx) -> y
:note: This opt add AssertOp. Otherwise, it would remove shape and Notes
-----
This opt add AssertOp. Otherwise, it would remove shape and
index error. If you want to get rid of them, see the index error. If you want to get rid of them, see the
:ref:`unsafe_optimization` section. :ref:`unsafe_optimization` section.
...@@ -2862,6 +2907,7 @@ def local_useless_inc_subtensor_alloc(node): ...@@ -2862,6 +2907,7 @@ def local_useless_inc_subtensor_alloc(node):
Replaces an [Advanced]IncSubtensor[1], whose increment is an `alloc` of Replaces an [Advanced]IncSubtensor[1], whose increment is an `alloc` of
a fully or partially broadcastable variable, by one that skips the a fully or partially broadcastable variable, by one that skips the
intermediate `alloc` where possible. intermediate `alloc` where possible.
""" """
if isinstance(node.op, (IncSubtensor, if isinstance(node.op, (IncSubtensor,
AdvancedIncSubtensor, AdvancedIncSubtensor,
...@@ -2962,7 +3008,8 @@ def local_useless_inc_subtensor_alloc(node): ...@@ -2962,7 +3008,8 @@ def local_useless_inc_subtensor_alloc(node):
@gof.local_optimizer([T.Rebroadcast]) @gof.local_optimizer([T.Rebroadcast])
def local_useless_rebroadcast(node): def local_useless_rebroadcast(node):
""" """
Remove Rebroadcast if it does not actually change the broadcasting pattern Remove Rebroadcast if it does not actually change the broadcasting pattern.
""" """
if isinstance(node.op, T.Rebroadcast): if isinstance(node.op, T.Rebroadcast):
x = node.inputs[0] x = node.inputs[0]
...@@ -2992,6 +3039,7 @@ def local_rebroadcast_lift(node): ...@@ -2992,6 +3039,7 @@ def local_rebroadcast_lift(node):
Rebroadcast(Elemwise(x)) => Elemwise(Rebroadcast(x)) Rebroadcast(Elemwise(x)) => Elemwise(Rebroadcast(x))
Rebroadcast(Rebroadcast(x)) => Rebroadcast(x) Rebroadcast(Rebroadcast(x)) => Rebroadcast(x)
""" """
op = node.op op = node.op
if not isinstance(op, T.Rebroadcast): if not isinstance(op, T.Rebroadcast):
...@@ -3023,8 +3071,14 @@ def apply_rebroadcast_opt(rval): ...@@ -3023,8 +3071,14 @@ def apply_rebroadcast_opt(rval):
Apply as many times as required the optimization local_useless_rebroadcast Apply as many times as required the optimization local_useless_rebroadcast
and local_rebroadcast_lift. and local_rebroadcast_lift.
:param rval: a Variable Parameters
:return: a Variable (the same if no optimization can be applied) ----------
rval: a Variable
Returns
-------
A Variable (the same if no optimization can be applied)
""" """
changed = True changed = True
...@@ -3056,6 +3110,7 @@ def local_join_1(node): ...@@ -3056,6 +3110,7 @@ def local_join_1(node):
"""Join(i, x) => x """Join(i, x) => x
Remove Join() when only one element is joined. Remove Join() when only one element is joined.
""" """
if not isinstance(node.op, T.Join): if not isinstance(node.op, T.Join):
return return
...@@ -3070,7 +3125,8 @@ def local_join_1(node): ...@@ -3070,7 +3125,8 @@ def local_join_1(node):
def local_join_empty(node): def local_join_empty(node):
"""Join(i, x, y, empty) => Join(i, x, y) """Join(i, x, y, empty) => Join(i, x, y)
remove empty inputs to joins. The empty inputs can be anywhere. Remove empty inputs to joins. The empty inputs can be anywhere.
""" """
if not isinstance(node.op, T.Join): if not isinstance(node.op, T.Join):
return return
...@@ -3147,6 +3203,7 @@ def local_remove_switch_const_cond(node): ...@@ -3147,6 +3203,7 @@ def local_remove_switch_const_cond(node):
T.switch(cond,left,right) --> T.switch(cond,left,right) -->
if cond is constant and cond == 0: right if cond is constant and cond == 0: right
if cond is constant and cond != 0: left if cond is constant and cond != 0: left
""" """
if (isinstance(node.op, T.Elemwise) and if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, scalar.basic.Switch)): isinstance(node.op.scalar_op, scalar.basic.Switch)):
...@@ -3183,7 +3240,9 @@ def local_mul_switch_sink(node): ...@@ -3183,7 +3240,9 @@ def local_mul_switch_sink(node):
This is useful because A and B may not be numerically stable and give This is useful because A and B may not be numerically stable and give
NaN or inf values for cases where the switch returns 0. NaN or inf values for cases where the switch returns 0.
With this optimization T.grad(T.switch(...)) has the right behavior. With this optimization T.grad(T.switch(...)) has the right behavior.
Exemple:
Examples
--------
x -> f(x) x -> f(x)
x -> g(x) x -> g(x)
y = T.switch(cond,f(x),g(x)) y = T.switch(cond,f(x),g(x))
...@@ -3193,6 +3252,7 @@ def local_mul_switch_sink(node): ...@@ -3193,6 +3252,7 @@ def local_mul_switch_sink(node):
T.grad(y,x) -> switch(cond,grad(f(x),x), 0) + switch(cond,0,grad(g(x),x)) T.grad(y,x) -> switch(cond,grad(f(x),x), 0) + switch(cond,0,grad(g(x),x))
This will be particularly useful for the lazyif because we skip This will be particularly useful for the lazyif because we skip
an entire part of the graph. an entire part of the graph.
""" """
if node.op != T.mul: if node.op != T.mul:
return False return False
...@@ -3234,6 +3294,7 @@ def local_div_switch_sink(node): ...@@ -3234,6 +3294,7 @@ def local_div_switch_sink(node):
This is useful because A may not be numerically stable and give This is useful because A may not be numerically stable and give
NaN or inf values for cases where the switch returns 0. NaN or inf values for cases where the switch returns 0.
See local_mul_switch_sink for more details. See local_mul_switch_sink for more details.
""" """
if (node.op != T.true_div and node.op != T.int_div): if (node.op != T.true_div and node.op != T.int_div):
return False return False
...@@ -3308,6 +3369,7 @@ def local_useless_split(node): ...@@ -3308,6 +3369,7 @@ def local_useless_split(node):
""" Split{n_splits=1}(x, y) -> x """ Split{n_splits=1}(x, y) -> x
Remove Split with only 1 split. Remove Split with only 1 split.
""" """
if isinstance(node.op, T.Split): if isinstance(node.op, T.Split):
if node.op.len_splits == 1: if node.op.len_splits == 1:
...@@ -3329,6 +3391,7 @@ def local_flatten_lift(node): ...@@ -3329,6 +3391,7 @@ def local_flatten_lift(node):
This optimization is needed by optimization This optimization is needed by optimization
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a flatten. nnet/sigm.py:log1msigm_to_softplus to get applied when there is a flatten.
""" """
if (isinstance(node.op, T.Flatten) and if (isinstance(node.op, T.Flatten) and
node.inputs[0].owner and node.inputs[0].owner and
...@@ -3347,6 +3410,7 @@ def local_flatten_lift(node): ...@@ -3347,6 +3410,7 @@ def local_flatten_lift(node):
def local_reshape_chain(node): def local_reshape_chain(node):
""" """
Reshape(Reshape(shape1),shape2) -> Reshape(shape2) Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
""" """
if not opt.check_chain(node, T.Reshape, T.Reshape): if not opt.check_chain(node, T.Reshape, T.Reshape):
return False return False
...@@ -3378,6 +3442,7 @@ def local_reshape_lift(node): ...@@ -3378,6 +3442,7 @@ def local_reshape_lift(node):
This optimization is needed by optimization This optimization is needed by optimization
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a reshape. nnet/sigm.py:log1msigm_to_softplus to get applied when there is a reshape.
""" """
if (isinstance(node.op, T.Reshape) and if (isinstance(node.op, T.Reshape) and
node.inputs[0].owner and node.inputs[0].owner and
...@@ -3526,15 +3591,20 @@ class Canonizer(gof.LocalOptimizer): ...@@ -3526,15 +3591,20 @@ class Canonizer(gof.LocalOptimizer):
Usage: Canonizer(main, inverse, reciprocal, calculate) Usage: Canonizer(main, inverse, reciprocal, calculate)
* main: a suitable Op class that is commutative, associative and Parameters
----------
main
A suitable Op class that is commutative, associative and
takes one to an arbitrary number of inputs, e.g. add or takes one to an arbitrary number of inputs, e.g. add or
mul mul
* inverse: an Op class such that inverse(main(x, y), y) == x inverse
An Op class such that inverse(main(x, y), y) == x
e.g. sub or true_div e.g. sub or true_div
* reciprocal: a function such that main(x, reciprocal(y)) == reciprocal
inverse(x, y) e.g. neg or inv A function such that main(x, reciprocal(y)) == inverse(x, y)
e.g. neg or inv
* calculate: function that takes a list of numpy.ndarray instances calculate
Function that takes a list of numpy.ndarray instances
for the numerator, another list for the denominator, for the numerator, another list for the denominator,
and calculates inverse(main(*num), main(*denum)). It and calculates inverse(main(*num), main(*denum)). It
takes a keyword argument, aslist. If True, the value takes a keyword argument, aslist. If True, the value
...@@ -3545,7 +3615,8 @@ class Canonizer(gof.LocalOptimizer): ...@@ -3545,7 +3615,8 @@ class Canonizer(gof.LocalOptimizer):
The variable is a local_optimizer. It is best used with a TopoOptimizer in The variable is a local_optimizer. It is best used with a TopoOptimizer in
in_to_out order. in_to_out order.
Examples: Examples
--------
T = theano.tensor T = theano.tensor
add_canonizer = Canonizer(T.add, T.sub, T.neg, add_canonizer = Canonizer(T.add, T.sub, T.neg,
lambda n, d: sum(n) - sum(d)) lambda n, d: sum(n) - sum(d))
...@@ -3563,6 +3634,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -3563,6 +3634,7 @@ class Canonizer(gof.LocalOptimizer):
2 * x / 2 -> x 2 * x / 2 -> x
x * y * z -> Elemwise(T.mul){x,y,z} #only one pass over the memory. x * y * z -> Elemwise(T.mul){x,y,z} #only one pass over the memory.
!-> Elemwise(T.mul){x,Elemwise(T.mul){y,z}} !-> Elemwise(T.mul){x,Elemwise(T.mul){y,z}}
""" """
def __init__(self, main, inverse, reciprocal, calculate, def __init__(self, main, inverse, reciprocal, calculate,
...@@ -3747,8 +3819,11 @@ class Canonizer(gof.LocalOptimizer): ...@@ -3747,8 +3819,11 @@ class Canonizer(gof.LocalOptimizer):
@staticmethod @staticmethod
def get_constant(v): def get_constant(v):
""" """
Returns a numeric constant if v is a Constant or, well, a Returns
-------
A numeric constant if v is a Constant or, well, a
numeric constant. If v is a plain Variable, returns None. numeric constant. If v is a plain Variable, returns None.
""" """
if isinstance(v, Variable): if isinstance(v, Variable):
try: try:
...@@ -3762,6 +3837,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -3762,6 +3837,7 @@ class Canonizer(gof.LocalOptimizer):
""" """
Shorthand for: Shorthand for:
self.simplify_constants(*self.simplify_factors(num, denum)) self.simplify_constants(*self.simplify_factors(num, denum))
""" """
rval = self.simplify_constants(*self.simplify_factors(num, denum), rval = self.simplify_constants(*self.simplify_factors(num, denum),
out_type=out_type) out_type=out_type)
...@@ -3781,6 +3857,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -3781,6 +3857,7 @@ class Canonizer(gof.LocalOptimizer):
[x], [x] -> [], [] [x], [x] -> [], []
[x, y], [x] -> [y], [] [x, y], [x] -> [y], []
[a, b], [c, d] -> [a, b], [c, d] [a, b], [c, d] -> [a, b], [c, d]
""" """
for v in list(num): for v in list(num):
if v in denum: if v in denum:
...@@ -3790,18 +3867,22 @@ class Canonizer(gof.LocalOptimizer): ...@@ -3790,18 +3867,22 @@ class Canonizer(gof.LocalOptimizer):
def simplify_constants(self, orig_num, orig_denum, out_type=None): def simplify_constants(self, orig_num, orig_denum, out_type=None):
""" """
Find all constants and put them together into a single constant.
Finds all constants in orig_num and orig_denum (using Finds all constants in orig_num and orig_denum (using
get_constant) and puts them together into a single get_constant) and puts them together into a single
constant. The constant is inserted as the first element of the constant. The constant is inserted as the first element of the
numerator. If the constant is the neutral element, it is numerator. If the constant is the neutral element, it is
removed from the numerator. Examples: removed from the numerator.
Examples
--------
Let main be multiplication: Let main be multiplication:
[2, 3, x], [] -> [6, x], [] [2, 3, x], [] -> [6, x], []
[x, y, 2], [4, z] -> [0.5, x, y], [z] [x, y, 2], [4, z] -> [0.5, x, y], [z]
[x, 2, y], [z, 2] -> [x, y], [z] [x, 2, y], [z, 2] -> [x, y], [z]
""" """
# Lists representing the numerator and denominator # Lists representing the numerator and denominator
...@@ -3969,13 +4050,15 @@ register_canonicalize(local_neg_to_mul) ...@@ -3969,13 +4050,15 @@ register_canonicalize(local_neg_to_mul)
@register_specialize @register_specialize
@gof.local_optimizer([T.Sum, T.elemwise.Prod]) @gof.local_optimizer([T.Sum, T.elemwise.Prod])
def local_sum_prod_mul_by_scalar(node): def local_sum_prod_mul_by_scalar(node):
"""sum(scalar * smth) -> scalar * sum(smth) """
sum(scalar * smth) -> scalar * sum(smth)
sum(-smth) -> -sum(smth) sum(-smth) -> -sum(smth)
or or
prod(scalar * smth) -> scalar ** size(smth) * prod(smth) prod(scalar * smth) -> scalar ** size(smth) * prod(smth)
prod(-smth) -> -1 ** size(smth) * prod(smth) prod(-smth) -> -1 ** size(smth) * prod(smth)
""" """
# TODO: if the thing inside the Sum is a division, # TODO: if the thing inside the Sum is a division,
# we should get at the numerator.... # we should get at the numerator....
...@@ -4040,8 +4123,11 @@ def local_elemwise_sub_zeros(node): ...@@ -4040,8 +4123,11 @@ def local_elemwise_sub_zeros(node):
@register_specialize @register_specialize
@gof.local_optimizer([T.Sum]) @gof.local_optimizer([T.Sum])
def local_sum_div_dimshuffle(node): def local_sum_div_dimshuffle(node):
'''sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b, """
if dimension l of the DimShuffle is 'x'.''' sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
if dimension l of the DimShuffle is 'x'.
"""
# TODO: extend it to product, and quotient of products # TODO: extend it to product, and quotient of products
# It does not make much sense now to extend it to the case where the # It does not make much sense now to extend it to the case where the
...@@ -4128,8 +4214,10 @@ def local_sum_div_dimshuffle(node): ...@@ -4128,8 +4214,10 @@ def local_sum_div_dimshuffle(node):
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.Sum, T.elemwise.Prod]) @gof.local_optimizer([T.Sum, T.elemwise.Prod])
def local_sum_prod_all_to_none(node): def local_sum_prod_all_to_none(node):
"""Sum{0,1,...N} -> Sum{} or """
Sum{0,1,...N} -> Sum{} or
Prod{0,1,...N} -> Prod{} Prod{0,1,...N} -> Prod{}
""" """
if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod): if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod):
opt_type = T.Sum if isinstance(node.op, T.Sum) else T.elemwise.Prod opt_type = T.Sum if isinstance(node.op, T.Sum) else T.elemwise.Prod
...@@ -4148,6 +4236,7 @@ def local_op_of_op(node): ...@@ -4148,6 +4236,7 @@ def local_op_of_op(node):
Prod(Prod()) -> single Prod() Prod(Prod()) -> single Prod()
or or
Sum(Sum()) -> single Sum() Sum(Sum()) -> single Sum()
""" """
if isinstance(node.op, T.elemwise.Prod) or isinstance(node.op, T.Sum): if isinstance(node.op, T.elemwise.Prod) or isinstance(node.op, T.Sum):
opt_type = T.Sum if isinstance(node.op, T.Sum) else T.elemwise.Prod opt_type = T.Sum if isinstance(node.op, T.Sum) else T.elemwise.Prod
...@@ -4219,14 +4308,16 @@ ALL_REDUCE = [T.elemwise.CAReduce, T.elemwise.All, T.elemwise.Any, ...@@ -4219,14 +4308,16 @@ ALL_REDUCE = [T.elemwise.CAReduce, T.elemwise.All, T.elemwise.Any,
@register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce @register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce
@gof.local_optimizer(ALL_REDUCE) @gof.local_optimizer(ALL_REDUCE)
def local_reduce_join(node): def local_reduce_join(node):
"""Reduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b) """
Reduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
:note: supported scalar.op are Maximum, Minimum in some cases and Notes
Add and Mul in all cases. -----
Supported scalar.op are Maximum, Minimum in some cases and Add and Mul in
all cases.
:note: Currently we must reduce on axis 0. It is probably Currently we must reduce on axis 0. It is probably extensible to the case
extensible to the case where we join and reduce on the same where we join and reduce on the same set of axis.
set of axis.
""" """
if (isinstance(node.op, T.CAReduce) and if (isinstance(node.op, T.CAReduce) and
...@@ -4312,7 +4403,7 @@ def local_cut_useless_reduce(node): ...@@ -4312,7 +4403,7 @@ def local_cut_useless_reduce(node):
@register_specialize @register_specialize
@gof.local_optimizer(ALL_REDUCE) @gof.local_optimizer(ALL_REDUCE)
def local_reduce_broadcastable(node): def local_reduce_broadcastable(node):
"""Remove reduction over broadcastable dimensions""" """Remove reduction over broadcastable dimensions."""
if isinstance(node.op, T.CAReduce): if isinstance(node.op, T.CAReduce):
reduced, = node.inputs reduced, = node.inputs
odtype = node.outputs[0].dtype odtype = node.outputs[0].dtype
...@@ -4351,9 +4442,11 @@ def local_reduce_broadcastable(node): ...@@ -4351,9 +4442,11 @@ def local_reduce_broadcastable(node):
@register_specialize @register_specialize
@gof.local_optimizer([T.Sum, T.elemwise.Prod]) @gof.local_optimizer([T.Sum, T.elemwise.Prod])
def local_opt_alloc(node): def local_opt_alloc(node):
""" sum(alloc(constant,shapes...)) => constant*prod(shapes) """
sum(alloc(constant,shapes...)) => constant*prod(shapes)
or or
prod(alloc(constant,shapes...)) => constant**prod(shapes) prod(alloc(constant,shapes...)) => constant**prod(shapes)
""" """
if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod): if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod):
node_inps, = node.inputs node_inps, = node.inputs
...@@ -4406,9 +4499,11 @@ def local_neg_neg(node): ...@@ -4406,9 +4499,11 @@ def local_neg_neg(node):
@register_specialize @register_specialize
@gof.local_optimizer([T.neg]) @gof.local_optimizer([T.neg])
def local_neg_div_neg(node): def local_neg_div_neg(node):
"""- (-a / b) -> a / b """
- (-a / b) -> a / b
Also performs - (c / b) -> ((-c) / b) when c is a scalar constant. Also performs - (c / b) -> ((-c) / b) when c is a scalar constant.
""" """
if node.op == T.neg: if node.op == T.neg:
if node.inputs[0].owner and node.inputs[0].owner.op == T.true_div: if node.inputs[0].owner and node.inputs[0].owner.op == T.true_div:
...@@ -4427,8 +4522,10 @@ def local_neg_div_neg(node): ...@@ -4427,8 +4522,10 @@ def local_neg_div_neg(node):
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.mul])
def local_mul_zero(node): def local_mul_zero(node):
"""As part of canonicalization, we replace multiplication by zero """
As part of canonicalization, we replace multiplication by zero
with zero. with zero.
""" """
if node.op == T.mul: if node.op == T.mul:
otype = node.outputs[0].type otype = node.outputs[0].type
...@@ -4489,10 +4586,12 @@ register_canonicalize(local_pow_canonicalize) ...@@ -4489,10 +4586,12 @@ register_canonicalize(local_pow_canonicalize)
@register_specialize @register_specialize
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.mul])
def local_mul_to_sqr(node): def local_mul_to_sqr(node):
"""x*x -> sqr(x) """
x*x -> sqr(x)
This is faster on the GPU when memory fetching is a big part of This is faster on the GPU when memory fetching is a big part of
the computation time. the computation time.
""" """
if node.op == T.mul: if node.op == T.mul:
if len(node.inputs) == 2: if len(node.inputs) == 2:
...@@ -4620,7 +4719,8 @@ def local_pow_specialize_device(node): ...@@ -4620,7 +4719,8 @@ def local_pow_specialize_device(node):
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.mul])
def local_mul_specialize(node): def local_mul_specialize(node):
"""Remove special-case constants from mul arguments and useless neg in inputs. """
Remove special-case constants from mul arguments and useless neg in inputs.
mul(-1, x) -> neg(x) mul(-1, x) -> neg(x)
mul(1, x, y) -> mul(x, y) mul(1, x, y) -> mul(x, y)
...@@ -4629,6 +4729,7 @@ def local_mul_specialize(node): ...@@ -4629,6 +4729,7 @@ def local_mul_specialize(node):
This is not done if we would add more nodes in the graph, like with: This is not done if we would add more nodes in the graph, like with:
mul(-1, x, y) -/-> neg(mul(x, y)) mul(-1, x, y) -/-> neg(mul(x, y))
""" """
# here, we are past the point of canonicalization, so we don't # here, we are past the point of canonicalization, so we don't
# want to put in un-necessary fills. # want to put in un-necessary fills.
...@@ -4766,8 +4867,9 @@ local_mul_canonizer.add_simplifier(check_for_x_over_absX, 'X_over_absX') ...@@ -4766,8 +4867,9 @@ local_mul_canonizer.add_simplifier(check_for_x_over_absX, 'X_over_absX')
@gof.local_optimizer([T.abs_]) @gof.local_optimizer([T.abs_])
def local_abs_lift(node): def local_abs_lift(node):
""" """
move the abs toward the input. This is needed for Move the abs toward the input.
check_for_x_over_absX to apply in more cases.
This is needed for check_for_x_over_absX to apply in more cases.
""" """
if node.op == T.abs_ and node.inputs[0].owner: if node.op == T.abs_ and node.inputs[0].owner:
...@@ -4783,7 +4885,7 @@ def local_abs_lift(node): ...@@ -4783,7 +4885,7 @@ def local_abs_lift(node):
@gof.local_optimizer([T.mul, T.true_div]) @gof.local_optimizer([T.mul, T.true_div])
def local_abs_merge(node): def local_abs_merge(node):
""" """
merge abs generated by local_abs_lift when the canonizer doesn't Merge abs generated by local_abs_lift when the canonizer doesn't
need it anymore need it anymore
""" """
...@@ -4968,6 +5070,8 @@ def attempt_distribution(factor, num, denum, out_type): ...@@ -4968,6 +5070,8 @@ def attempt_distribution(factor, num, denum, out_type):
@gof.local_optimizer([T.mul, T.true_div, T.inv]) @gof.local_optimizer([T.mul, T.true_div, T.inv])
def local_greedy_distributor(node): def local_greedy_distributor(node):
""" """
Optimize by reducing the number of multiplications and/or divisions.
This optimization tries to apply distributivity of multiplication This optimization tries to apply distributivity of multiplication
to addition in order to reduce the number of multiplications to addition in order to reduce the number of multiplications
and/or divisions that must be done. The algorithm weighs division and/or divisions that must be done. The algorithm weighs division
...@@ -4985,6 +5089,7 @@ def local_greedy_distributor(node): ...@@ -4985,6 +5089,7 @@ def local_greedy_distributor(node):
This optimization aims to reduce computational cost. It may also This optimization aims to reduce computational cost. It may also
increase numerical stability, e.g. when x and/or y tend to 0 in increase numerical stability, e.g. when x and/or y tend to 0 in
example 1. example 1.
""" """
out = node.outputs[0] out = node.outputs[0]
...@@ -5083,7 +5188,12 @@ def constant_folding(node): ...@@ -5083,7 +5188,12 @@ def constant_folding(node):
def _is_1(expr): def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1 """
Returns
-------
bool
True iff expr is a constant close to 1.
""" """
try: try:
v = get_scalar_constant_value(expr) v = get_scalar_constant_value(expr)
...@@ -5093,7 +5203,12 @@ def _is_1(expr): ...@@ -5093,7 +5203,12 @@ def _is_1(expr):
def _is_minus1(expr): def _is_minus1(expr):
"""rtype bool. True iff expr is a constant close to -1 """
Returns
-------
bool
True iff expr is a constant close to -1.
""" """
try: try:
v = get_scalar_constant_value(expr) v = get_scalar_constant_value(expr)
...@@ -5103,13 +5218,13 @@ def _is_minus1(expr): ...@@ -5103,13 +5218,13 @@ def _is_minus1(expr):
def get_clients(node): def get_clients(node):
"Used by erf/erfc opt to track less frequent op" """Used by erf/erfc opt to track less frequent op."""
return [c for c, i in node.outputs[0].clients return [c for c, i in node.outputs[0].clients
if c != "output"] if c != "output"]
def get_clients2(node): def get_clients2(node):
"Used by erf/erfc opt to track less frequent op" """Used by erf/erfc opt to track less frequent op."""
l = [] l = []
for c, i in node.outputs[0].clients: for c, i in node.outputs[0].clients:
if c != "output": if c != "output":
...@@ -5622,9 +5737,12 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, ...@@ -5622,9 +5737,12 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
""" """
We parametrize it to make it work for Elemwise and GpuElemwise op. We parametrize it to make it work for Elemwise and GpuElemwise op.
:param OP: GpuElemwise or Elemwise class (the one that we want to fuse) Parameters
----------
:param max_input_fct: a function that returns the maximum number of inputs OP
GpuElemwise or Elemwise class (the one that we want to fuse)
max_input_fct
A function that returns the maximum number of inputs
that this elemwise can take (useful for GpuElemwise). that this elemwise can take (useful for GpuElemwise).
GPU kernel currently has a limit of 256 bytes for GPU kernel currently has a limit of 256 bytes for
the size of all parameters passed to it. As currently the size of all parameters passed to it. As currently
...@@ -5634,6 +5752,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, ...@@ -5634,6 +5752,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
On the CPU we limit to 32 input variables On the CPU we limit to 32 input variables
since that is the maximum numpy support. since that is the maximum numpy support.
""" """
if maker is None: if maker is None:
def maker(node, scalar_op): def maker(node, scalar_op):
...@@ -5647,6 +5766,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, ...@@ -5647,6 +5766,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
For mixed dtype, we let the Composite op do the cast. It lets the C For mixed dtype, we let the Composite op do the cast. It lets the C
compiler do the cast. compiler do the cast.
The number of dimensions is validated at call time by theano itself. The number of dimensions is validated at call time by theano itself.
""" """
# META TODO: PUT THESE THINGS IN TRAC, NOT TODO NOTES!! # META TODO: PUT THESE THINGS IN TRAC, NOT TODO NOTES!!
# TODO: use broadcast flag? # TODO: use broadcast flag?
...@@ -5862,7 +5982,7 @@ local_elemwise_fusion = local_elemwise_fusion_op(T.Elemwise, ...@@ -5862,7 +5982,7 @@ local_elemwise_fusion = local_elemwise_fusion_op(T.Elemwise,
class FusionOptimizer(Optimizer): class FusionOptimizer(Optimizer):
"""Graph optimizer for Fusion of elemwise operations""" """Graph optimizer for Fusion of elemwise operations."""
def __init__(self, local_optimizer): def __init__(self, local_optimizer):
Optimizer.__init__(self) Optimizer.__init__(self)
self.optimizer = local_optimizer self.optimizer = local_optimizer
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论