提交 7e822dbb 作者: Iban Harlouchet

numpydoc for theano/tensor/nnet/sigm.py

上级 430561ad
"""Ops and optimizations: sigmoid, softplus """
Ops and optimizations: sigmoid, softplus.
These functions implement special cases of exp and log to improve numerical
stability.
These functions implement special cases of exp and log to improve numerical stability.
""" """
from __future__ import print_function from __future__ import print_function
...@@ -25,6 +28,7 @@ from theano.tensor import elemwise, opt, NotScalarConstantError ...@@ -25,6 +28,7 @@ from theano.tensor import elemwise, opt, NotScalarConstantError
class ScalarSigmoid(scalar.UnaryScalarOp): class ScalarSigmoid(scalar.UnaryScalarOp):
""" """
This is just speed opt. Not for stability. This is just speed opt. Not for stability.
""" """
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
...@@ -126,7 +130,8 @@ class ScalarSigmoid(scalar.UnaryScalarOp): ...@@ -126,7 +130,8 @@ class ScalarSigmoid(scalar.UnaryScalarOp):
@staticmethod @staticmethod
def gen_graph(): def gen_graph():
""" """
This method was used to generate the graph: sigmoid_prec.png in the doc This method was used to generate the graph: sigmoid_prec.png in the doc.
""" """
data = numpy.arange(-15, 15, .1) data = numpy.arange(-15, 15, .1)
val = 1 / (1 + numpy.exp(-data)) val = 1 / (1 + numpy.exp(-data))
...@@ -173,6 +178,7 @@ pprint.assign(sigmoid, printing.FunctionPrinter('sigmoid')) ...@@ -173,6 +178,7 @@ pprint.assign(sigmoid, printing.FunctionPrinter('sigmoid'))
class UltraFastScalarSigmoid(scalar.UnaryScalarOp): class UltraFastScalarSigmoid(scalar.UnaryScalarOp):
""" """
This is just speed opt. Not for stability. This is just speed opt. Not for stability.
""" """
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
...@@ -245,7 +251,7 @@ def local_ultra_fast_sigmoid(node): ...@@ -245,7 +251,7 @@ def local_ultra_fast_sigmoid(node):
When enabled, change all sigmoid to ultra_fast_sigmoid. When enabled, change all sigmoid to ultra_fast_sigmoid.
For example do mode.including('local_ultra_fast_sigmoid') For example do mode.including('local_ultra_fast_sigmoid')
or use the Theano flag optimizer_including=local_ultra_fast_sigmoid or use the Theano flag optimizer_including=local_ultra_fast_sigmoid.
This speeds up the sigmoid op by using an approximation. This speeds up the sigmoid op by using an approximation.
...@@ -269,11 +275,12 @@ theano.compile.optdb['uncanonicalize'].register("local_ultra_fast_sigmoid", ...@@ -269,11 +275,12 @@ theano.compile.optdb['uncanonicalize'].register("local_ultra_fast_sigmoid",
def hard_sigmoid(x): def hard_sigmoid(x):
"""An approximation of sigmoid. """
An approximation of sigmoid.
More approximate and faster than ultra_fast_sigmoid. More approximate and faster than ultra_fast_sigmoid.
Approx in 3 parts: 0, scaled linear, 1 Approx in 3 parts: 0, scaled linear, 1.
Removing the slope and shift does not make it faster. Removing the slope and shift does not make it faster.
...@@ -375,7 +382,12 @@ logsigm_to_softplus = gof.PatternSub( ...@@ -375,7 +382,12 @@ logsigm_to_softplus = gof.PatternSub(
def _is_1(expr): def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1 """
Returns
-------
bool
True iff expr is a constant close to 1.
""" """
try: try:
v = opt.get_scalar_constant_value(expr) v = opt.get_scalar_constant_value(expr)
...@@ -405,8 +417,11 @@ opt.register_stabilize(log1pexp_to_softplus, name='log1pexp_to_softplus') ...@@ -405,8 +417,11 @@ opt.register_stabilize(log1pexp_to_softplus, name='log1pexp_to_softplus')
def is_1pexp(t): def is_1pexp(t):
""" """
If 't' is of the form (1+exp(x)), return (False, x). Returns
Else return None. -------
If 't' is of the form (1+exp(x)), return (False, x).
Else return None.
""" """
if t.owner and t.owner.op == tensor.add: if t.owner and t.owner.op == tensor.add:
scalars, scalar_inputs, nonconsts = \ scalars, scalar_inputs, nonconsts = \
...@@ -449,11 +464,17 @@ def is_exp(var): ...@@ -449,11 +464,17 @@ def is_exp(var):
""" """
Match a variable with either of the `exp(x)` or `-exp(x)` patterns. Match a variable with either of the `exp(x)` or `-exp(x)` patterns.
:param var: The Variable to analyze. Parameters
----------
var
The Variable to analyze.
Returns
-------
A pair (b, x) with `b` a boolean set to True if `var` is of the
form `-exp(x)` and False if `var` is of the form `exp(x)`. If `var`
cannot be cast into either form, then return `None`.
:return: A pair (b, x) with `b` a boolean set to True if `var` is of the
form `-exp(x)` and False if `var` is of the form `exp(x)`. If `var` cannot
be cast into either form, then return `None`.
""" """
neg = False neg = False
neg_info = is_neg(var) neg_info = is_neg(var)
...@@ -468,10 +489,16 @@ def is_mul(var): ...@@ -468,10 +489,16 @@ def is_mul(var):
""" """
Match a variable with `x * y * z * ...`. Match a variable with `x * y * z * ...`.
:param var: The Variable to analyze. Parameters
----------
var
The Variable to analyze.
Returns
-------
A list [x, y, z, ...] if `var` is of the form `x * y * z * ...`,
or None if `var` cannot be cast into this form.
:return: A list [x, y, z, ...] if `var` is of the form `x * y * z * ...`,
or None if `var` cannot be cast into this form.
""" """
if var.owner and var.owner.op == tensor.mul: if var.owner and var.owner.op == tensor.mul:
return var.owner.inputs return var.owner.inputs
...@@ -504,9 +531,15 @@ def is_neg(var): ...@@ -504,9 +531,15 @@ def is_neg(var):
""" """
Match a variable with the `-x` pattern. Match a variable with the `-x` pattern.
:param var: The Variable to analyze. Parameters
----------
var
The Variable to analyze.
Returns
-------
`x` if `var` is of the form `-x`, or None otherwise.
:return: `x` if `var` is of the form `-x`, or None otherwise.
""" """
apply = var.owner apply = var.owner
if not apply: if not apply:
...@@ -538,8 +571,10 @@ def is_neg(var): ...@@ -538,8 +571,10 @@ def is_neg(var):
@opt.register_stabilize @opt.register_stabilize
@gof.local_optimizer([tensor.true_div]) @gof.local_optimizer([tensor.true_div])
def local_exp_over_1_plus_exp(node): def local_exp_over_1_plus_exp(node):
"""exp(x)/(1+exp(x)) -> sigm(x) """
exp(x)/(1+exp(x)) -> sigm(x)
c/(1+exp(x)) -> c*sigm(-x) c/(1+exp(x)) -> c*sigm(-x)
""" """
# this optimization should be done for numerical stability # this optimization should be done for numerical stability
# so we don't care to check client counts # so we don't care to check client counts
...@@ -585,20 +620,27 @@ def parse_mul_tree(root): ...@@ -585,20 +620,27 @@ def parse_mul_tree(root):
""" """
Parse a tree of multiplications starting at the given root. Parse a tree of multiplications starting at the given root.
:param root: The variable at the root of the tree. Parameters
----------
root
The variable at the root of the tree.
:return: A tree where each non-leaf node corresponds to a multiplication Returns
in the computation of `root`, represented by the list of its inputs. Each -------
input is a pair [n, x] with `n` a boolean value indicating whether A tree where each non-leaf node corresponds to a multiplication
sub-tree `x` should be negated. in the computation of `root`, represented by the list of its inputs.
Each input is a pair [n, x] with `n` a boolean value indicating whether
sub-tree `x` should be negated.
Examples: Examples
--------
x * y -> [False, [[False, x], [False, y]]] x * y -> [False, [[False, x], [False, y]]]
-(x * y) -> [True, [[False, x], [False, y]]] -(x * y) -> [True, [[False, x], [False, y]]]
-x * y -> [False, [[True, x], [False, y]]] -x * y -> [False, [[True, x], [False, y]]]
-x -> [True, x] -x -> [True, x]
(x * y) * -z -> [False, [[False, [[False, x], [False, y]]], (x * y) * -z -> [False, [[False, [[False, x], [False, y]]],
[True, z]]] [True, z]]]
""" """
# Is it a multiplication? # Is it a multiplication?
mul_info = is_mul(root) mul_info = is_mul(root)
...@@ -619,29 +661,35 @@ def parse_mul_tree(root): ...@@ -619,29 +661,35 @@ def parse_mul_tree(root):
def replace_leaf(arg, leaves, new_leaves, op, neg): def replace_leaf(arg, leaves, new_leaves, op, neg):
""" """
Attempts to replace a leaf of a multiplication tree. Attempt to replace a leaf of a multiplication tree.
We search for a leaf in `leaves` whose argument is `arg`, and if we find We search for a leaf in `leaves` whose argument is `arg`, and if we find
one, we remove it from `leaves` and add to `new_leaves` a leaf with one, we remove it from `leaves` and add to `new_leaves` a leaf with
argument `arg` and variable `op(arg)`. argument `arg` and variable `op(arg)`.
:param arg: The argument of the leaf we are looking for. Parameters
----------
:param leaves: List of leaves to look into. Each leaf should be a pair arg
(x, l) with `x` the argument of the Op found in the leaf, and `l` the The argument of the leaf we are looking for.
actual leaf as found in a multiplication tree output by `parse_mul_tree` leaves
(i.e. a pair [boolean, variable]). List of leaves to look into. Each leaf should be a pair
(x, l) with `x` the argument of the Op found in the leaf, and `l` the
actual leaf as found in a multiplication tree output by `parse_mul_tree`
(i.e. a pair [boolean, variable]).
new_leaves
If a replacement occurred, then the leaf is removed from `leaves`
and added to the list `new_leaves` (after being modified by `op`).
op
A function that, when applied to `arg`, returns the Variable
we want to replace the original leaf variable with.
neg : bool
If True, then the boolean value associated to the leaf should
be swapped. If False, then this value should remain unchanged.
Returns
-------
True if a replacement occurred, or False otherwise.
:param new_leaves: If a replacement occurred, then the leaf is removed from
`leaves` and added to the list `new_leaves` (after being modified by `op`).
:param op: A function that, when applied to `arg`, returns the Variable
we want to replace the original leaf variable with.
:param neg: If True, then the boolean value associated to the leaf should
be swapped. If False, then this value should remain unchanged.
:return: True if a replacement occurred, or False otherwise.
""" """
for idx, x in enumerate(leaves): for idx, x in enumerate(leaves):
if x[0] == arg: if x[0] == arg:
...@@ -657,12 +705,18 @@ def simplify_mul(tree): ...@@ -657,12 +705,18 @@ def simplify_mul(tree):
""" """
Simplify a multiplication tree. Simplify a multiplication tree.
:param tree: A multiplication tree (as output by `parse_mul_tree`). Parameters
----------
tree
A multiplication tree (as output by `parse_mul_tree`).
Returns
-------
A multiplication tree computing the same output as `tree` but without
useless multiplications by 1 nor -1 (identified by leaves of the form
[False, None] or [True, None] respectively). Useless multiplications
(with less than two inputs) are also removed from the tree.
:return: A multiplication tree computing the same output as `tree` but
without useless multiplications by 1 nor -1 (identified by leaves of the
form [False, None] or [True, None] respectively). Useless multiplications
(with less than two inputs) are also removed from the tree.
""" """
neg, inputs = tree neg, inputs = tree
if isinstance(inputs, list): if isinstance(inputs, list):
...@@ -694,12 +748,17 @@ def compute_mul(tree): ...@@ -694,12 +748,17 @@ def compute_mul(tree):
Compute the Variable that is the output of a multiplication tree. Compute the Variable that is the output of a multiplication tree.
This is the inverse of the operation performed by `parse_mul_tree`, i.e. This is the inverse of the operation performed by `parse_mul_tree`, i.e.
compute_mul(parse_mul_tree(tree)) == tree compute_mul(parse_mul_tree(tree)) == tree.
:param tree: A multiplication tree (as output by `parse_mul_tree`). Parameters
----------
tree
A multiplication tree (as output by `parse_mul_tree`).
Returns
-------
A Variable that computes the multiplication represented by the tree.
:return: A Variable that computes the multiplication represented by the
tree.
""" """
neg, inputs = tree neg, inputs = tree
if inputs is None: if inputs is None:
...@@ -727,32 +786,38 @@ def perform_sigm_times_exp(tree, exp_x=None, exp_minus_x=None, sigm_x=None, ...@@ -727,32 +786,38 @@ def perform_sigm_times_exp(tree, exp_x=None, exp_minus_x=None, sigm_x=None,
by replacing matching pairs (exp, sigmoid) with the desired optimized by replacing matching pairs (exp, sigmoid) with the desired optimized
version. version.
:param tree: The sub-tree to operate on. Parameters
----------
:exp_x: List of arguments x so that `exp(x)` exists somewhere in the whole tree
multiplication tree. Each argument is a pair (x, leaf) with `x` the The sub-tree to operate on.
argument of the exponential, and `leaf` the corresponding leaf in the exp_x
multiplication tree (of the form [n, exp(x)] -- see `parse_mul_tree`). List of arguments x so that `exp(x)` exists somewhere in the whole
If None, this argument is initialized to an empty list. multiplication tree. Each argument is a pair (x, leaf) with `x` the
argument of the exponential, and `leaf` the corresponding leaf in the
:param exp_minus_x: Similar to `exp_x`, but for `exp(-x)`. multiplication tree (of the form [n, exp(x)] -- see `parse_mul_tree`).
If None, this argument is initialized to an empty list.
:param sigm_x: Similar to `exp_x`, but for `sigmoid(x)`. exp_minus_x
Similar to `exp_x`, but for `exp(-x)`.
sigm_x
Similar to `exp_x`, but for `sigmoid(x)`.
sigm_minus_x
Similar to `exp_x`, but for `sigmoid(-x)`.
parent
Parent of `tree` (None if `tree` is the global root).
child_idx
Index of `tree` in its parent's inputs (None if `tree` is the global
root).
full_tree
The global multiplication tree (should not be set except by recursive
calls to this function). Used for debugging only.
Returns
-------
bool
True if a modification was performed somewhere in the whole multiplication
tree, or False otherwise.
:param sigm_minus_x: Similar to `exp_x`, but for `sigmoid(-x)`.
:param parent: Parent of `tree` (None if `tree` is the global root).
:param child_idx: Index of `tree` in its parent's inputs (None if `tree` is
the global root).
:param full_tree: The global multiplication tree (should not be set except
by recursive calls to this function). Used for debugging only.
:return: True if a modification was performed somewhere in the whole
multiplication tree, or False otherwise.
""" """
if exp_x is None: if exp_x is None:
exp_x = [] exp_x = []
if exp_minus_x is None: if exp_minus_x is None:
...@@ -836,6 +901,7 @@ def local_sigm_times_exp(node): ...@@ -836,6 +901,7 @@ def local_sigm_times_exp(node):
""" """
exp(x) * sigm(-x) -> sigm(x) exp(x) * sigm(-x) -> sigm(x)
exp(-x) * sigm(x) -> sigm(-x) exp(-x) * sigm(x) -> sigm(-x)
""" """
# Bail early if it is not a multiplication. # Bail early if it is not a multiplication.
if node.op != tensor.mul: if node.op != tensor.mul:
...@@ -859,6 +925,7 @@ def local_sigm_times_exp(node): ...@@ -859,6 +925,7 @@ def local_sigm_times_exp(node):
def local_inv_1_plus_exp(node): def local_inv_1_plus_exp(node):
""" """
1/(1+exp(x)) -> sigm(-x) 1/(1+exp(x)) -> sigm(-x)
""" """
# this optimization should be done for numerical stability # this optimization should be done for numerical stability
# so we don't care to check client counts # so we don't care to check client counts
...@@ -883,6 +950,7 @@ def local_inv_1_plus_exp(node): ...@@ -883,6 +950,7 @@ def local_inv_1_plus_exp(node):
def local_1msigmoid(node): def local_1msigmoid(node):
""" """
1-sigm(x) -> sigm(-x) 1-sigm(x) -> sigm(-x)
""" """
if node.op == tensor.sub: if node.op == tensor.sub:
sub_l, sub_r = node.inputs sub_l, sub_r = node.inputs
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论