提交 2a658f43 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

function to compute the jacobian

上级 30c24c78
......@@ -710,3 +710,63 @@ class GradientError(Exception):
args_msg)
verify_grad.E_grad = GradientError
def jacobian(expression, wrt, consider_constant=None, warn_type=False,
disconnected_inputs='raise'):
"""
:type expression: Vector (1-dimensional) `Variable`
:type wrt: 'Variable' or list of `Variables`s
:param consider_constant: a list of expressions not to backpropagate
through
:param warn_type: a value of True will cause warnings to be logged for any
Op that emits a gradient that does not match its input type.
:type disconnected_inputs: string
:param disconnected_inputs: Defines the behaviour if some of the variables
in ``wrt`` are not part of the computational graph computing ``cost``
(or if all links are non-differentiable). The possible values are:
- 'ignore': considers that the gradient on these parameters is zero.
- 'warn': consider the gradient zero, and print a warning.
- 'raise': raise an exception.
:return: either a instance of `Variable` or list/tuple of `Variable`s
(depending upon `wrt`). If an element of `wrt` is not
differentiable with respect to the output, then a zero
variable is returned. The return value is of same type
as `wrt`: a list/tuple or TensorVariable in all cases.
"""
# Check inputs have the right format
assert isisntance(expression, TensorVariable), \
"tensor.jacobian expects a Tensor Variable as `expression`"
assert expression.ndim == 1, \
"tensor.jacobian expects a 1 dimensional variable as `expression`"
if isintance(wrt, (list, tuple)):
use_list = True
wrt = list(wrt)
else:
use_list = False
wrt = [wrt]
def inner_function(*args):
idx = args[0]
expr = args[1]
return [grad(exp[idx],
inp,
consider_constant=consider_constant,
warn_type=warn_type,
disconnected_inputs=disconnected_inputs)
for inp in [args[2:]]]
# Computing the gradients does not affect the random seeds on any random
# generator used n expression (because during computing gradients we are
# just backtracking over old values. (rp Jan 2012 - if anyone has a
# counter example please show me)
jacobs, _ = scan(inner_function,
sequences=arange(expression.shape[0]),
non_sequences=[expression] + wrt)
if not use_list:
jacobs = jacobs[0]
return jacobs
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论