提交 7e822dbb authored 作者: Iban Harlouchet's avatar Iban Harlouchet

numpydoc for theano/tensor/nnet/sigm.py

上级 430561ad
"""Ops and optimizations: sigmoid, softplus """
Ops and optimizations: sigmoid, softplus.
These functions implement special cases of exp and log to improve numerical
stability.
These functions implement special cases of exp and log to improve numerical stability.
""" """
from __future__ import print_function from __future__ import print_function
...@@ -25,6 +28,7 @@ from theano.tensor import elemwise, opt, NotScalarConstantError ...@@ -25,6 +28,7 @@ from theano.tensor import elemwise, opt, NotScalarConstantError
class ScalarSigmoid(scalar.UnaryScalarOp): class ScalarSigmoid(scalar.UnaryScalarOp):
""" """
This is just speed opt. Not for stability. This is just speed opt. Not for stability.
""" """
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
...@@ -126,7 +130,8 @@ class ScalarSigmoid(scalar.UnaryScalarOp): ...@@ -126,7 +130,8 @@ class ScalarSigmoid(scalar.UnaryScalarOp):
@staticmethod @staticmethod
def gen_graph(): def gen_graph():
""" """
This method was used to generate the graph: sigmoid_prec.png in the doc This method was used to generate the graph: sigmoid_prec.png in the doc.
""" """
data = numpy.arange(-15, 15, .1) data = numpy.arange(-15, 15, .1)
val = 1 / (1 + numpy.exp(-data)) val = 1 / (1 + numpy.exp(-data))
...@@ -173,6 +178,7 @@ pprint.assign(sigmoid, printing.FunctionPrinter('sigmoid')) ...@@ -173,6 +178,7 @@ pprint.assign(sigmoid, printing.FunctionPrinter('sigmoid'))
class UltraFastScalarSigmoid(scalar.UnaryScalarOp): class UltraFastScalarSigmoid(scalar.UnaryScalarOp):
""" """
This is just speed opt. Not for stability. This is just speed opt. Not for stability.
""" """
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
...@@ -245,7 +251,7 @@ def local_ultra_fast_sigmoid(node): ...@@ -245,7 +251,7 @@ def local_ultra_fast_sigmoid(node):
When enabled, change all sigmoid to ultra_fast_sigmoid. When enabled, change all sigmoid to ultra_fast_sigmoid.
For example do mode.including('local_ultra_fast_sigmoid') For example do mode.including('local_ultra_fast_sigmoid')
or use the Theano flag optimizer_including=local_ultra_fast_sigmoid or use the Theano flag optimizer_including=local_ultra_fast_sigmoid.
This speeds up the sigmoid op by using an approximation. This speeds up the sigmoid op by using an approximation.
...@@ -269,11 +275,12 @@ theano.compile.optdb['uncanonicalize'].register("local_ultra_fast_sigmoid", ...@@ -269,11 +275,12 @@ theano.compile.optdb['uncanonicalize'].register("local_ultra_fast_sigmoid",
def hard_sigmoid(x): def hard_sigmoid(x):
"""An approximation of sigmoid. """
An approximation of sigmoid.
More approximate and faster than ultra_fast_sigmoid. More approximate and faster than ultra_fast_sigmoid.
Approx in 3 parts: 0, scaled linear, 1 Approx in 3 parts: 0, scaled linear, 1.
Removing the slope and shift does not make it faster. Removing the slope and shift does not make it faster.
...@@ -375,7 +382,12 @@ logsigm_to_softplus = gof.PatternSub( ...@@ -375,7 +382,12 @@ logsigm_to_softplus = gof.PatternSub(
def _is_1(expr): def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1 """
Returns
-------
bool
True iff expr is a constant close to 1.
""" """
try: try:
v = opt.get_scalar_constant_value(expr) v = opt.get_scalar_constant_value(expr)
...@@ -405,8 +417,11 @@ opt.register_stabilize(log1pexp_to_softplus, name='log1pexp_to_softplus') ...@@ -405,8 +417,11 @@ opt.register_stabilize(log1pexp_to_softplus, name='log1pexp_to_softplus')
def is_1pexp(t): def is_1pexp(t):
""" """
Returns
-------
If 't' is of the form (1+exp(x)), return (False, x). If 't' is of the form (1+exp(x)), return (False, x).
Else return None. Else return None.
""" """
if t.owner and t.owner.op == tensor.add: if t.owner and t.owner.op == tensor.add:
scalars, scalar_inputs, nonconsts = \ scalars, scalar_inputs, nonconsts = \
...@@ -449,11 +464,17 @@ def is_exp(var): ...@@ -449,11 +464,17 @@ def is_exp(var):
""" """
Match a variable with either of the `exp(x)` or `-exp(x)` patterns. Match a variable with either of the `exp(x)` or `-exp(x)` patterns.
:param var: The Variable to analyze. Parameters
----------
var
The Variable to analyze.
Returns
-------
A pair (b, x) with `b` a boolean set to True if `var` is of the
form `-exp(x)` and False if `var` is of the form `exp(x)`. If `var`
cannot be cast into either form, then return `None`.
:return: A pair (b, x) with `b` a boolean set to True if `var` is of the
form `-exp(x)` and False if `var` is of the form `exp(x)`. If `var` cannot
be cast into either form, then return `None`.
""" """
neg = False neg = False
neg_info = is_neg(var) neg_info = is_neg(var)
...@@ -468,10 +489,16 @@ def is_mul(var): ...@@ -468,10 +489,16 @@ def is_mul(var):
""" """
Match a variable with `x * y * z * ...`. Match a variable with `x * y * z * ...`.
:param var: The Variable to analyze. Parameters
----------
var
The Variable to analyze.
:return: A list [x, y, z, ...] if `var` is of the form `x * y * z * ...`, Returns
-------
A list [x, y, z, ...] if `var` is of the form `x * y * z * ...`,
or None if `var` cannot be cast into this form. or None if `var` cannot be cast into this form.
""" """
if var.owner and var.owner.op == tensor.mul: if var.owner and var.owner.op == tensor.mul:
return var.owner.inputs return var.owner.inputs
...@@ -504,9 +531,15 @@ def is_neg(var): ...@@ -504,9 +531,15 @@ def is_neg(var):
""" """
Match a variable with the `-x` pattern. Match a variable with the `-x` pattern.
:param var: The Variable to analyze. Parameters
----------
var
The Variable to analyze.
Returns
-------
`x` if `var` is of the form `-x`, or None otherwise.
:return: `x` if `var` is of the form `-x`, or None otherwise.
""" """
apply = var.owner apply = var.owner
if not apply: if not apply:
...@@ -538,8 +571,10 @@ def is_neg(var): ...@@ -538,8 +571,10 @@ def is_neg(var):
@opt.register_stabilize @opt.register_stabilize
@gof.local_optimizer([tensor.true_div]) @gof.local_optimizer([tensor.true_div])
def local_exp_over_1_plus_exp(node): def local_exp_over_1_plus_exp(node):
"""exp(x)/(1+exp(x)) -> sigm(x) """
exp(x)/(1+exp(x)) -> sigm(x)
c/(1+exp(x)) -> c*sigm(-x) c/(1+exp(x)) -> c*sigm(-x)
""" """
# this optimization should be done for numerical stability # this optimization should be done for numerical stability
# so we don't care to check client counts # so we don't care to check client counts
...@@ -585,20 +620,27 @@ def parse_mul_tree(root): ...@@ -585,20 +620,27 @@ def parse_mul_tree(root):
""" """
Parse a tree of multiplications starting at the given root. Parse a tree of multiplications starting at the given root.
:param root: The variable at the root of the tree. Parameters
----------
root
The variable at the root of the tree.
:return: A tree where each non-leaf node corresponds to a multiplication Returns
in the computation of `root`, represented by the list of its inputs. Each -------
input is a pair [n, x] with `n` a boolean value indicating whether A tree where each non-leaf node corresponds to a multiplication
in the computation of `root`, represented by the list of its inputs.
Each input is a pair [n, x] with `n` a boolean value indicating whether
sub-tree `x` should be negated. sub-tree `x` should be negated.
Examples: Examples
--------
x * y -> [False, [[False, x], [False, y]]] x * y -> [False, [[False, x], [False, y]]]
-(x * y) -> [True, [[False, x], [False, y]]] -(x * y) -> [True, [[False, x], [False, y]]]
-x * y -> [False, [[True, x], [False, y]]] -x * y -> [False, [[True, x], [False, y]]]
-x -> [True, x] -x -> [True, x]
(x * y) * -z -> [False, [[False, [[False, x], [False, y]]], (x * y) * -z -> [False, [[False, [[False, x], [False, y]]],
[True, z]]] [True, z]]]
""" """
# Is it a multiplication? # Is it a multiplication?
mul_info = is_mul(root) mul_info = is_mul(root)
...@@ -619,29 +661,35 @@ def parse_mul_tree(root): ...@@ -619,29 +661,35 @@ def parse_mul_tree(root):
def replace_leaf(arg, leaves, new_leaves, op, neg): def replace_leaf(arg, leaves, new_leaves, op, neg):
""" """
Attempts to replace a leaf of a multiplication tree. Attempt to replace a leaf of a multiplication tree.
We search for a leaf in `leaves` whose argument is `arg`, and if we find We search for a leaf in `leaves` whose argument is `arg`, and if we find
one, we remove it from `leaves` and add to `new_leaves` a leaf with one, we remove it from `leaves` and add to `new_leaves` a leaf with
argument `arg` and variable `op(arg)`. argument `arg` and variable `op(arg)`.
:param arg: The argument of the leaf we are looking for. Parameters
----------
:param leaves: List of leaves to look into. Each leaf should be a pair arg
The argument of the leaf we are looking for.
leaves
List of leaves to look into. Each leaf should be a pair
(x, l) with `x` the argument of the Op found in the leaf, and `l` the (x, l) with `x` the argument of the Op found in the leaf, and `l` the
actual leaf as found in a multiplication tree output by `parse_mul_tree` actual leaf as found in a multiplication tree output by `parse_mul_tree`
(i.e. a pair [boolean, variable]). (i.e. a pair [boolean, variable]).
new_leaves
:param new_leaves: If a replacement occurred, then the leaf is removed from If a replacement occurred, then the leaf is removed from `leaves`
`leaves` and added to the list `new_leaves` (after being modified by `op`). and added to the list `new_leaves` (after being modified by `op`).
op
:param op: A function that, when applied to `arg`, returns the Variable A function that, when applied to `arg`, returns the Variable
we want to replace the original leaf variable with. we want to replace the original leaf variable with.
neg : bool
:param neg: If True, then the boolean value associated to the leaf should If True, then the boolean value associated to the leaf should
be swapped. If False, then this value should remain unchanged. be swapped. If False, then this value should remain unchanged.
:return: True if a replacement occurred, or False otherwise. Returns
-------
True if a replacement occurred, or False otherwise.
""" """
for idx, x in enumerate(leaves): for idx, x in enumerate(leaves):
if x[0] == arg: if x[0] == arg:
...@@ -657,12 +705,18 @@ def simplify_mul(tree): ...@@ -657,12 +705,18 @@ def simplify_mul(tree):
""" """
Simplify a multiplication tree. Simplify a multiplication tree.
:param tree: A multiplication tree (as output by `parse_mul_tree`). Parameters
----------
tree
A multiplication tree (as output by `parse_mul_tree`).
:return: A multiplication tree computing the same output as `tree` but Returns
without useless multiplications by 1 nor -1 (identified by leaves of the -------
form [False, None] or [True, None] respectively). Useless multiplications A multiplication tree computing the same output as `tree` but without
useless multiplications by 1 nor -1 (identified by leaves of the form
[False, None] or [True, None] respectively). Useless multiplications
(with less than two inputs) are also removed from the tree. (with less than two inputs) are also removed from the tree.
""" """
neg, inputs = tree neg, inputs = tree
if isinstance(inputs, list): if isinstance(inputs, list):
...@@ -694,12 +748,17 @@ def compute_mul(tree): ...@@ -694,12 +748,17 @@ def compute_mul(tree):
Compute the Variable that is the output of a multiplication tree. Compute the Variable that is the output of a multiplication tree.
This is the inverse of the operation performed by `parse_mul_tree`, i.e. This is the inverse of the operation performed by `parse_mul_tree`, i.e.
compute_mul(parse_mul_tree(tree)) == tree compute_mul(parse_mul_tree(tree)) == tree.
:param tree: A multiplication tree (as output by `parse_mul_tree`). Parameters
----------
tree
A multiplication tree (as output by `parse_mul_tree`).
Returns
-------
A Variable that computes the multiplication represented by the tree.
:return: A Variable that computes the multiplication represented by the
tree.
""" """
neg, inputs = tree neg, inputs = tree
if inputs is None: if inputs is None:
...@@ -727,32 +786,38 @@ def perform_sigm_times_exp(tree, exp_x=None, exp_minus_x=None, sigm_x=None, ...@@ -727,32 +786,38 @@ def perform_sigm_times_exp(tree, exp_x=None, exp_minus_x=None, sigm_x=None,
by replacing matching pairs (exp, sigmoid) with the desired optimized by replacing matching pairs (exp, sigmoid) with the desired optimized
version. version.
:param tree: The sub-tree to operate on. Parameters
----------
:exp_x: List of arguments x so that `exp(x)` exists somewhere in the whole tree
The sub-tree to operate on.
exp_x
List of arguments x so that `exp(x)` exists somewhere in the whole
multiplication tree. Each argument is a pair (x, leaf) with `x` the multiplication tree. Each argument is a pair (x, leaf) with `x` the
argument of the exponential, and `leaf` the corresponding leaf in the argument of the exponential, and `leaf` the corresponding leaf in the
multiplication tree (of the form [n, exp(x)] -- see `parse_mul_tree`). multiplication tree (of the form [n, exp(x)] -- see `parse_mul_tree`).
If None, this argument is initialized to an empty list. If None, this argument is initialized to an empty list.
exp_minus_x
Similar to `exp_x`, but for `exp(-x)`.
sigm_x
Similar to `exp_x`, but for `sigmoid(x)`.
sigm_minus_x
Similar to `exp_x`, but for `sigmoid(-x)`.
parent
Parent of `tree` (None if `tree` is the global root).
child_idx
Index of `tree` in its parent's inputs (None if `tree` is the global
root).
full_tree
The global multiplication tree (should not be set except by recursive
calls to this function). Used for debugging only.
Returns
-------
bool
True if a modification was performed somewhere in the whole multiplication
tree, or False otherwise.
:param exp_minus_x: Similar to `exp_x`, but for `exp(-x)`.
:param sigm_x: Similar to `exp_x`, but for `sigmoid(x)`.
:param sigm_minus_x: Similar to `exp_x`, but for `sigmoid(-x)`.
:param parent: Parent of `tree` (None if `tree` is the global root).
:param child_idx: Index of `tree` in its parent's inputs (None if `tree` is
the global root).
:param full_tree: The global multiplication tree (should not be set except
by recursive calls to this function). Used for debugging only.
:return: True if a modification was performed somewhere in the whole
multiplication tree, or False otherwise.
""" """
if exp_x is None: if exp_x is None:
exp_x = [] exp_x = []
if exp_minus_x is None: if exp_minus_x is None:
...@@ -836,6 +901,7 @@ def local_sigm_times_exp(node): ...@@ -836,6 +901,7 @@ def local_sigm_times_exp(node):
""" """
exp(x) * sigm(-x) -> sigm(x) exp(x) * sigm(-x) -> sigm(x)
exp(-x) * sigm(x) -> sigm(-x) exp(-x) * sigm(x) -> sigm(-x)
""" """
# Bail early if it is not a multiplication. # Bail early if it is not a multiplication.
if node.op != tensor.mul: if node.op != tensor.mul:
...@@ -859,6 +925,7 @@ def local_sigm_times_exp(node): ...@@ -859,6 +925,7 @@ def local_sigm_times_exp(node):
def local_inv_1_plus_exp(node): def local_inv_1_plus_exp(node):
""" """
1/(1+exp(x)) -> sigm(-x) 1/(1+exp(x)) -> sigm(-x)
""" """
# this optimization should be done for numerical stability # this optimization should be done for numerical stability
# so we don't care to check client counts # so we don't care to check client counts
...@@ -883,6 +950,7 @@ def local_inv_1_plus_exp(node): ...@@ -883,6 +950,7 @@ def local_inv_1_plus_exp(node):
def local_1msigmoid(node): def local_1msigmoid(node):
""" """
1-sigm(x) -> sigm(-x) 1-sigm(x) -> sigm(-x)
""" """
if node.op == tensor.sub: if node.op == tensor.sub:
sub_l, sub_r = node.inputs sub_l, sub_r = node.inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论