提交 7e822dbb authored 作者: Iban Harlouchet's avatar Iban Harlouchet

numpydoc for theano/tensor/nnet/sigm.py

上级 430561ad
"""Ops and optimizations: sigmoid, softplus
"""
Ops and optimizations: sigmoid, softplus.
These functions implement special cases of exp and log to improve numerical
stability.
These functions implement special cases of exp and log to improve numerical stability.
"""
from __future__ import print_function
......@@ -25,6 +28,7 @@ from theano.tensor import elemwise, opt, NotScalarConstantError
class ScalarSigmoid(scalar.UnaryScalarOp):
"""
This is just speed opt. Not for stability.
"""
@staticmethod
def st_impl(x):
......@@ -126,7 +130,8 @@ class ScalarSigmoid(scalar.UnaryScalarOp):
@staticmethod
def gen_graph():
"""
This method was used to generate the graph: sigmoid_prec.png in the doc
This method was used to generate the graph: sigmoid_prec.png in the doc.
"""
data = numpy.arange(-15, 15, .1)
val = 1 / (1 + numpy.exp(-data))
......@@ -173,6 +178,7 @@ pprint.assign(sigmoid, printing.FunctionPrinter('sigmoid'))
class UltraFastScalarSigmoid(scalar.UnaryScalarOp):
"""
This is just speed opt. Not for stability.
"""
@staticmethod
def st_impl(x):
......@@ -245,7 +251,7 @@ def local_ultra_fast_sigmoid(node):
When enabled, change all sigmoid to ultra_fast_sigmoid.
For example do mode.including('local_ultra_fast_sigmoid')
or use the Theano flag optimizer_including=local_ultra_fast_sigmoid
or use the Theano flag optimizer_including=local_ultra_fast_sigmoid.
This speeds up the sigmoid op by using an approximation.
......@@ -269,11 +275,12 @@ theano.compile.optdb['uncanonicalize'].register("local_ultra_fast_sigmoid",
def hard_sigmoid(x):
"""An approximation of sigmoid.
"""
An approximation of sigmoid.
More approximate and faster than ultra_fast_sigmoid.
Approx in 3 parts: 0, scaled linear, 1
Approx in 3 parts: 0, scaled linear, 1.
Removing the slope and shift does not make it faster.
......@@ -375,7 +382,12 @@ logsigm_to_softplus = gof.PatternSub(
def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1
"""
Returns
-------
bool
True iff expr is a constant close to 1.
"""
try:
v = opt.get_scalar_constant_value(expr)
......@@ -405,8 +417,11 @@ opt.register_stabilize(log1pexp_to_softplus, name='log1pexp_to_softplus')
def is_1pexp(t):
"""
If 't' is of the form (1+exp(x)), return (False, x).
Else return None.
Returns
-------
If 't' is of the form (1+exp(x)), return (False, x).
Else return None.
"""
if t.owner and t.owner.op == tensor.add:
scalars, scalar_inputs, nonconsts = \
......@@ -449,11 +464,17 @@ def is_exp(var):
"""
Match a variable with either of the `exp(x)` or `-exp(x)` patterns.
:param var: The Variable to analyze.
Parameters
----------
var
The Variable to analyze.
Returns
-------
A pair (b, x) with `b` a boolean set to True if `var` is of the
form `-exp(x)` and False if `var` is of the form `exp(x)`. If `var`
cannot be cast into either form, then return `None`.
:return: A pair (b, x) with `b` a boolean set to True if `var` is of the
form `-exp(x)` and False if `var` is of the form `exp(x)`. If `var` cannot
be cast into either form, then return `None`.
"""
neg = False
neg_info = is_neg(var)
......@@ -468,10 +489,16 @@ def is_mul(var):
"""
Match a variable with `x * y * z * ...`.
:param var: The Variable to analyze.
Parameters
----------
var
The Variable to analyze.
Returns
-------
A list [x, y, z, ...] if `var` is of the form `x * y * z * ...`,
or None if `var` cannot be cast into this form.
:return: A list [x, y, z, ...] if `var` is of the form `x * y * z * ...`,
or None if `var` cannot be cast into this form.
"""
if var.owner and var.owner.op == tensor.mul:
return var.owner.inputs
......@@ -504,9 +531,15 @@ def is_neg(var):
"""
Match a variable with the `-x` pattern.
:param var: The Variable to analyze.
Parameters
----------
var
The Variable to analyze.
Returns
-------
`x` if `var` is of the form `-x`, or None otherwise.
:return: `x` if `var` is of the form `-x`, or None otherwise.
"""
apply = var.owner
if not apply:
......@@ -538,8 +571,10 @@ def is_neg(var):
@opt.register_stabilize
@gof.local_optimizer([tensor.true_div])
def local_exp_over_1_plus_exp(node):
"""exp(x)/(1+exp(x)) -> sigm(x)
"""
exp(x)/(1+exp(x)) -> sigm(x)
c/(1+exp(x)) -> c*sigm(-x)
"""
# this optimization should be done for numerical stability
# so we don't care to check client counts
......@@ -585,20 +620,27 @@ def parse_mul_tree(root):
"""
Parse a tree of multiplications starting at the given root.
:param root: The variable at the root of the tree.
Parameters
----------
root
The variable at the root of the tree.
:return: A tree where each non-leaf node corresponds to a multiplication
in the computation of `root`, represented by the list of its inputs. Each
input is a pair [n, x] with `n` a boolean value indicating whether
sub-tree `x` should be negated.
Returns
-------
A tree where each non-leaf node corresponds to a multiplication
in the computation of `root`, represented by the list of its inputs.
Each input is a pair [n, x] with `n` a boolean value indicating whether
sub-tree `x` should be negated.
Examples:
Examples
--------
x * y -> [False, [[False, x], [False, y]]]
-(x * y) -> [True, [[False, x], [False, y]]]
-x * y -> [False, [[True, x], [False, y]]]
-x -> [True, x]
(x * y) * -z -> [False, [[False, [[False, x], [False, y]]],
[True, z]]]
"""
# Is it a multiplication?
mul_info = is_mul(root)
......@@ -619,29 +661,35 @@ def parse_mul_tree(root):
def replace_leaf(arg, leaves, new_leaves, op, neg):
"""
Attempts to replace a leaf of a multiplication tree.
Attempt to replace a leaf of a multiplication tree.
We search for a leaf in `leaves` whose argument is `arg`, and if we find
one, we remove it from `leaves` and add to `new_leaves` a leaf with
argument `arg` and variable `op(arg)`.
:param arg: The argument of the leaf we are looking for.
:param leaves: List of leaves to look into. Each leaf should be a pair
(x, l) with `x` the argument of the Op found in the leaf, and `l` the
actual leaf as found in a multiplication tree output by `parse_mul_tree`
(i.e. a pair [boolean, variable]).
Parameters
----------
arg
The argument of the leaf we are looking for.
leaves
List of leaves to look into. Each leaf should be a pair
(x, l) with `x` the argument of the Op found in the leaf, and `l` the
actual leaf as found in a multiplication tree output by `parse_mul_tree`
(i.e. a pair [boolean, variable]).
new_leaves
If a replacement occurred, then the leaf is removed from `leaves`
and added to the list `new_leaves` (after being modified by `op`).
op
A function that, when applied to `arg`, returns the Variable
we want to replace the original leaf variable with.
neg : bool
If True, then the boolean value associated to the leaf should
be swapped. If False, then this value should remain unchanged.
Returns
-------
True if a replacement occurred, or False otherwise.
:param new_leaves: If a replacement occurred, then the leaf is removed from
`leaves` and added to the list `new_leaves` (after being modified by `op`).
:param op: A function that, when applied to `arg`, returns the Variable
we want to replace the original leaf variable with.
:param neg: If True, then the boolean value associated to the leaf should
be swapped. If False, then this value should remain unchanged.
:return: True if a replacement occurred, or False otherwise.
"""
for idx, x in enumerate(leaves):
if x[0] == arg:
......@@ -657,12 +705,18 @@ def simplify_mul(tree):
"""
Simplify a multiplication tree.
:param tree: A multiplication tree (as output by `parse_mul_tree`).
Parameters
----------
tree
A multiplication tree (as output by `parse_mul_tree`).
Returns
-------
A multiplication tree computing the same output as `tree` but without
useless multiplications by 1 nor -1 (identified by leaves of the form
[False, None] or [True, None] respectively). Useless multiplications
(with less than two inputs) are also removed from the tree.
:return: A multiplication tree computing the same output as `tree` but
without useless multiplications by 1 nor -1 (identified by leaves of the
form [False, None] or [True, None] respectively). Useless multiplications
(with less than two inputs) are also removed from the tree.
"""
neg, inputs = tree
if isinstance(inputs, list):
......@@ -694,12 +748,17 @@ def compute_mul(tree):
Compute the Variable that is the output of a multiplication tree.
This is the inverse of the operation performed by `parse_mul_tree`, i.e.
compute_mul(parse_mul_tree(tree)) == tree
compute_mul(parse_mul_tree(tree)) == tree.
:param tree: A multiplication tree (as output by `parse_mul_tree`).
Parameters
----------
tree
A multiplication tree (as output by `parse_mul_tree`).
Returns
-------
A Variable that computes the multiplication represented by the tree.
:return: A Variable that computes the multiplication represented by the
tree.
"""
neg, inputs = tree
if inputs is None:
......@@ -727,32 +786,38 @@ def perform_sigm_times_exp(tree, exp_x=None, exp_minus_x=None, sigm_x=None,
by replacing matching pairs (exp, sigmoid) with the desired optimized
version.
:param tree: The sub-tree to operate on.
:exp_x: List of arguments x so that `exp(x)` exists somewhere in the whole
multiplication tree. Each argument is a pair (x, leaf) with `x` the
argument of the exponential, and `leaf` the corresponding leaf in the
multiplication tree (of the form [n, exp(x)] -- see `parse_mul_tree`).
If None, this argument is initialized to an empty list.
:param exp_minus_x: Similar to `exp_x`, but for `exp(-x)`.
:param sigm_x: Similar to `exp_x`, but for `sigmoid(x)`.
Parameters
----------
tree
The sub-tree to operate on.
exp_x
List of arguments x so that `exp(x)` exists somewhere in the whole
multiplication tree. Each argument is a pair (x, leaf) with `x` the
argument of the exponential, and `leaf` the corresponding leaf in the
multiplication tree (of the form [n, exp(x)] -- see `parse_mul_tree`).
If None, this argument is initialized to an empty list.
exp_minus_x
Similar to `exp_x`, but for `exp(-x)`.
sigm_x
Similar to `exp_x`, but for `sigmoid(x)`.
sigm_minus_x
Similar to `exp_x`, but for `sigmoid(-x)`.
parent
Parent of `tree` (None if `tree` is the global root).
child_idx
Index of `tree` in its parent's inputs (None if `tree` is the global
root).
full_tree
The global multiplication tree (should not be set except by recursive
calls to this function). Used for debugging only.
Returns
-------
bool
True if a modification was performed somewhere in the whole multiplication
tree, or False otherwise.
:param sigm_minus_x: Similar to `exp_x`, but for `sigmoid(-x)`.
:param parent: Parent of `tree` (None if `tree` is the global root).
:param child_idx: Index of `tree` in its parent's inputs (None if `tree` is
the global root).
:param full_tree: The global multiplication tree (should not be set except
by recursive calls to this function). Used for debugging only.
:return: True if a modification was performed somewhere in the whole
multiplication tree, or False otherwise.
"""
if exp_x is None:
exp_x = []
if exp_minus_x is None:
......@@ -836,6 +901,7 @@ def local_sigm_times_exp(node):
"""
exp(x) * sigm(-x) -> sigm(x)
exp(-x) * sigm(x) -> sigm(-x)
"""
# Bail early if it is not a multiplication.
if node.op != tensor.mul:
......@@ -859,6 +925,7 @@ def local_sigm_times_exp(node):
def local_inv_1_plus_exp(node):
"""
1/(1+exp(x)) -> sigm(-x)
"""
# this optimization should be done for numerical stability
# so we don't care to check client counts
......@@ -883,6 +950,7 @@ def local_inv_1_plus_exp(node):
def local_1msigmoid(node):
"""
1-sigm(x) -> sigm(-x)
"""
if node.op == tensor.sub:
sub_l, sub_r = node.inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论