Commit ccfe2d3d, authored by Brandon T. Willard, committed by Brandon T. Willard

Refactor aesara.gradient and add type hints

Parent commit: 1d369b55
"""Driver for gradient calculations.""" """Driver for gradient calculations."""
import logging
import time import time
import warnings import warnings
from collections import OrderedDict
from functools import partial, reduce from functools import partial, reduce
from typing import TYPE_CHECKING, Callable, List, Optional, Union from typing import (
TYPE_CHECKING,
Callable,
Dict,
List,
Mapping,
MutableSequence,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
import numpy as np import numpy as np
from typing_extensions import Literal
import aesara import aesara
from aesara.compile.ops import ViewOp from aesara.compile.ops import ViewOp
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph import utils from aesara.graph import utils
from aesara.graph.basic import NominalVariable, Variable from aesara.graph.basic import Apply, NominalVariable, Variable
from aesara.graph.null_type import NullType, null_type from aesara.graph.null_type import NullType, null_type
from aesara.graph.op import get_test_values from aesara.graph.op import get_test_values
from aesara.graph.type import Type from aesara.graph.type import Type
...@@ -23,26 +34,18 @@ if TYPE_CHECKING: ...@@ -23,26 +34,18 @@ if TYPE_CHECKING:
from aesara.compile.mode import Mode from aesara.compile.mode import Mode
__docformat__ = "restructuredtext en" V = TypeVar("V", bound=Optional[Variable])
_logger = logging.getLogger("aesara.gradient")
# we can't do "import aesara.tensor"
# tensor depends on aesara.compile
# aesara.compile depends on aesara.gradient (this file)
# the reason aesara.compile depends on aesara.gradient
# is that aesara.compile.builders contains the op from graph
# functionality and it uses aesara.gradient to implement
# the new op's grad method
tensor = None
_msg_retType = "op.grad(...) returned a non-list" # TODO: Refactor this so that it's not a global variable
grad_time: float = 0.0
grad_time = 0
# TODO: Add `overload` variants
# NOTE(review): the original module defines ``V = TypeVar("V", bound=Optional[Variable])``
# at module level; it is re-declared here (without the project-type bound) so this
# block is self-contained.
V = TypeVar("V")


def as_list_or_tuple(
    use_list: bool, use_tuple: bool, outputs: Union[V, Sequence[V]]
) -> Union[V, List[V], Tuple[V, ...]]:
    """Return either a single object or a list/tuple of objects.

    If `use_list` is ``True``, `outputs` is returned as a list (if `outputs`
    is not a list or a tuple then it is converted into a one-element list).
    If `use_tuple` is ``True``, the same is done with a tuple instead.
    If both flags are ``False``, `outputs` must hold exactly one element,
    which is returned unwrapped.

    Parameters
    ----------
    use_list
        If ``True``, return a list.
    use_tuple
        If ``True``, return a tuple.  Mutually exclusive with `use_list`.
    outputs
        A single object or a sequence of objects to be (un)wrapped.

    Returns
    -------
    A list, a tuple, or a single object, depending on the flags.

    Raises
    ------
    ValueError
        If both flags are ``True``, or if neither flag is set and `outputs`
        is a sequence with a length other than one.
    """
    if use_list and use_tuple:
        raise ValueError("Both flags cannot be simultaneously True")

    if use_list or use_tuple:
        if isinstance(outputs, Sequence):
            # Already a sequence: just convert to the requested container.
            if use_list:
                return list(outputs)
            else:
                return tuple(outputs)
        else:
            # Single object: wrap it in a one-element container.
            if use_list:
                return [outputs]
            else:
                return (outputs,)
    else:
        if isinstance(outputs, Sequence):
            # Neither flag set: the sequence must contain exactly one item.
            if len(outputs) != 1:
                raise ValueError("Wrong arguments; expected a one element list")
            return outputs[0]
        else:
            return outputs
def grad_not_implemented(op, x_pos, x, comment=""): def grad_not_implemented(op, x_pos, x, comment=""):
...@@ -155,97 +161,87 @@ class DisconnectedType(Type): ...@@ -155,97 +161,87 @@ class DisconnectedType(Type):
disconnected_type = DisconnectedType() disconnected_type = DisconnectedType()
######################## def Rop(
# R Operator f: Union[Variable, Sequence[Variable]],
######################## wrt: Union[Variable, Sequence[Variable]],
eval_points: Union[Variable, Sequence[Variable]],
disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise",
return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
) -> Union[Optional[Variable], Sequence[Optional[Variable]]]:
"""Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`.
def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="zero"): Mathematically this stands for the Jacobian of `f` right multiplied by the
""" `eval_points`.
Computes the R operation on `f` wrt to `wrt` at `eval_points`.
Mathematically this stands for the jacobian of `f` wrt
to `wrt` right muliplied by the eval points.
Parameters Parameters
---------- ----------
f : :class:`~aesara.graph.basic.Variable` or list of Variables f
`f` stands for the output of the computational graph to which you The outputs of the computational graph to which the R-operator is
want to apply the R operator applied.
wrt : :class:`~aesara.graph.basic.Variable` or list of Variables wrt
variables for which you compute the R operator of the expression Variables for which the R-operator of `f` is computed.
described by `f` eval_points
eval_points : :class:`~aesara.graph.basic.Variable` or list of Variables Points at which to evaluate each of the variables in `wrt`.
evaluation points for each of the variables in `wrt` disconnected_outputs
disconnected_outputs : str
Defines the behaviour if some of the variables in `f` Defines the behaviour if some of the variables in `f`
have no dependency on any of the variable in `wrt` (or if have no dependency on any of the variable in `wrt` (or if
all links are non-differentiable). The possible values are: all links are non-differentiable). The possible values are:
- 'ignore': considers that the gradient on these parameters is zero. - ``'ignore'``: considers that the gradient on these parameters is zero.
- 'warn': consider the gradient zero, and print a warning. - ``'warn'``: consider the gradient zero, and print a warning.
- 'raise': raise DisconnectedInputError. - ``'raise'``: raise `DisconnectedInputError`.
return_disconnected : {'zero', 'None', 'Disconnected'} return_disconnected
- 'zero' : If wrt[i] is disconnected, return value i will be - ``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
wrt[i].zeros_like() ``wrt[i].zeros_like()``.
- 'None' : If wrt[i] is disconnected, return value i will be - ``'none'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
None ``None``
- 'Disconnected' : returns variables of type DisconnectedType - ``'disconnected'`` : returns variables of type `DisconnectedType`
Returns Returns
------- -------
:class:`~aesara.graph.basic.Variable` or list/tuple of Variables depending on type of f A symbolic expression such obeying
Symbolic expression such that ``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``,
R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]
where the indices in that expression are magic multidimensional where the indices in that expression are magic multidimensional
indices that specify both the position within a list and all indices that specify both the position within a list and all
coordinates of the tensor element in the last. coordinates of the tensor elements.
If `wrt` is a list/tuple, then return a list/tuple with the results. If `wrt` is a list/tuple, then return a list/tuple with the results.
""" """
using_list = isinstance(f, list)
using_tuple = isinstance(f, tuple)
if not isinstance(wrt, (list, tuple)): if not isinstance(wrt, (list, tuple)):
wrt = [wrt] _wrt: List[Variable] = [aesara.tensor.as_tensor_variable(wrt)]
else:
_wrt = [aesara.tensor.as_tensor_variable(x) for x in wrt]
if not isinstance(eval_points, (list, tuple)): if not isinstance(eval_points, (list, tuple)):
eval_points = [eval_points] _eval_points: List[Variable] = [aesara.tensor.as_tensor_variable(eval_points)]
else:
_eval_points = [aesara.tensor.as_tensor_variable(x) for x in eval_points]
if not isinstance(f, (list, tuple)): if not isinstance(f, (list, tuple)):
f = [f] _f: List[Variable] = [aesara.tensor.as_tensor_variable(f)]
else:
_f = [aesara.tensor.as_tensor_variable(x) for x in f]
if len(wrt) != len(eval_points): if len(_wrt) != len(_eval_points):
raise ValueError("`wrt` must be the same length as `eval_points`.") raise ValueError("`wrt` must be the same length as `eval_points`.")
# Check that each element of wrt corresponds to an element # Check that each element of wrt corresponds to an element
# of eval_points with the same dimensionality. # of eval_points with the same dimensionality.
for pack in enumerate(zip(wrt, eval_points)): for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points)):
i = pack[0]
wrt_elem, eval_point = pack[1]
if not isinstance(wrt_elem, Variable):
wrt_elem = aesara.tensor.as_tensor_variable(wrt_elem)
if not isinstance(eval_point, Variable):
eval_point = aesara.tensor.as_tensor_variable(eval_point)
try: try:
if wrt_elem.type.ndim != eval_point.type.ndim: if wrt_elem.type.ndim != eval_point.type.ndim:
raise ValueError( raise ValueError(
"Element " f"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: "
+ str(i) f"{wrt_elem.type.ndim} and {eval_point.type.ndim}"
+ " of wrt/eval_point have mismatched "
+ "dimensionality: "
+ str(wrt_elem.type.ndim)
+ " versus "
+ str(eval_point.type.ndim)
) )
except AttributeError: except AttributeError:
# wrt_elem and eval_point don't always have ndim like random type # wrt_elem and eval_point don't always have ndim like random type
# Tensor, Sparse have the ndim attribute # Tensor, Sparse have the ndim attribute
pass pass
seen_nodes = OrderedDict() seen_nodes: Dict[Apply, Sequence[Variable]] = {}
def _traverse(node): def _traverse(node):
"""TODO: writeme""" """TODO: writeme"""
...@@ -260,8 +256,8 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected=" ...@@ -260,8 +256,8 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="
# inputs of the node # inputs of the node
local_eval_points = [] local_eval_points = []
for inp in inputs: for inp in inputs:
if inp in wrt: if inp in _wrt:
local_eval_points.append(eval_points[wrt.index(inp)]) local_eval_points.append(_eval_points[_wrt.index(inp)])
elif inp.owner is None: elif inp.owner is None:
try: try:
local_eval_points.append(inp.zeros_like()) local_eval_points.append(inp.zeros_like())
...@@ -316,13 +312,13 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected=" ...@@ -316,13 +312,13 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="
# end _traverse # end _traverse
# Populate the dictionary # Populate the dictionary
for out in f: for out in _f:
_traverse(out.owner) _traverse(out.owner)
rval = [] rval: List[Optional[Variable]] = []
for out in f: for out in _f:
if out in wrt: if out in _wrt:
rval.append(eval_points[wrt.index(out)]) rval.append(_eval_points[_wrt.index(out)])
elif ( elif (
seen_nodes.get(out.owner, None) is None seen_nodes.get(out.owner, None) is None
or seen_nodes[out.owner][out.owner.outputs.index(out)] is None or seen_nodes[out.owner][out.owner.outputs.index(out)] is None
...@@ -361,81 +357,89 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected=" ...@@ -361,81 +357,89 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="
else: else:
rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)]) rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)])
return format_as(using_list, using_tuple, rval) using_list = isinstance(f, list)
using_tuple = isinstance(f, tuple)
return as_list_or_tuple(using_list, using_tuple, rval)
def Lop(
    f: Union[Variable, Sequence[Variable]],
    wrt: Union[Variable, Sequence[Variable]],
    eval_points: Union[Variable, Sequence[Variable]],
    consider_constant: Optional[Sequence[Variable]] = None,
    disconnected_inputs: Literal["ignore", "warn", "raise"] = "raise",
) -> Union[Optional[Variable], Sequence[Optional[Variable]]]:
    """Computes the L-operator applied to `f` with respect to `wrt` at `eval_points`.

    Mathematically this stands for the Jacobian of `f` with respect to `wrt`
    left multiplied by the `eval_points`.

    Parameters
    ----------
    f
        The outputs of the computational graph to which the L-operator is
        applied.
    wrt
        Variables for which the L-operator of `f` is computed.
    eval_points
        Points at which to evaluate each of the outputs in `f`.
    consider_constant
        See `grad`.
    disconnected_inputs
        See `grad`.

    Returns
    -------
    A symbolic expression satisfying
        ``L_op[i] = sum_i (d f[i] / d wrt[j]) eval_point[i]``
    where the indices in that expression are magic multidimensional
    indices that specify both the position within a list and all
    coordinates of the tensor elements.
    If `wrt` is a list/tuple, then return a list/tuple with the results.
    """
    # Normalize `eval_points` and `f` to lists of tensor variables.
    if not isinstance(eval_points, (list, tuple)):
        _eval_points: List[Variable] = [aesara.tensor.as_tensor_variable(eval_points)]
    else:
        _eval_points = [aesara.tensor.as_tensor_variable(x) for x in eval_points]

    if not isinstance(f, (list, tuple)):
        _f: List[Variable] = [aesara.tensor.as_tensor_variable(f)]
    else:
        _f = [aesara.tensor.as_tensor_variable(x) for x in f]

    # Copy so that `grad` cannot mutate the caller's sequence.
    grads = list(_eval_points)

    if not isinstance(wrt, (list, tuple)):
        _wrt: List[Variable] = [aesara.tensor.as_tensor_variable(wrt)]
    else:
        _wrt = [aesara.tensor.as_tensor_variable(x) for x in wrt]

    assert len(_f) == len(grads)

    # The L-operator is implemented as back-propagation with the `eval_points`
    # supplied as the known output gradients.
    known = dict(zip(_f, grads))

    ret = grad(
        cost=None,
        known_grads=known,
        consider_constant=consider_constant,
        wrt=_wrt,
        disconnected_inputs=disconnected_inputs,
    )

    # Mirror the container type of the *original* `wrt` argument.
    using_list = isinstance(wrt, list)
    using_tuple = isinstance(wrt, tuple)
    return as_list_or_tuple(using_list, using_tuple, ret)
#########################
# Gradient
#########################
def grad( def grad(
cost, cost: Optional[Variable],
wrt, wrt: Union[Variable, Sequence[Variable]],
consider_constant=None, consider_constant: Optional[Sequence[Variable]] = None,
disconnected_inputs="raise", disconnected_inputs: Literal["ignore", "warn", "raise"] = "raise",
add_names=True, add_names: bool = True,
known_grads=None, known_grads: Optional[Mapping[Variable, Variable]] = None,
return_disconnected="zero", return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
null_gradients="raise", null_gradients: Literal["raise", "return"] = "raise",
): ) -> Union[Optional[Variable], Sequence[Optional[Variable]]]:
""" """
Return symbolic gradients of one cost with respect to one or more variables. Return symbolic gradients of one cost with respect to one or more variables.
...@@ -445,49 +449,47 @@ def grad( ...@@ -445,49 +449,47 @@ def grad(
Parameters Parameters
---------- ----------
cost : :class:`~aesara.graph.basic.Variable` scalar (0-dimensional) tensor variable or ``None`` cost
Value that we are differentiating (that we want the gradient of). Value that we are differentiating (i.e. for which we want the
May be `None` if `known_grads` is provided. gradient). May be `None` if `known_grads` is provided.
wrt : :class:`~aesara.graph.basic.Variable` or list of Variables wrt
Term[s] with respect to which we want gradients The term(s) with respect to which we want gradients.
consider_constant : list of variables consider_constant
Expressions not to backpropagate through Expressions not to backpropagate through.
disconnected_inputs : {'ignore', 'warn', 'raise'} disconnected_inputs : {'ignore', 'warn', 'raise'}
Defines the behaviour if some of the variables in `wrt` are Defines the behaviour if some of the variables in `wrt` are
not part of the computational graph computing `cost` (or if not part of the computational graph computing `cost` (or if
all links are non-differentiable). The possible values are: all links are non-differentiable). The possible values are:
- 'ignore': considers that the gradient on these parameters is zero. - ``'ignore'``: considers that the gradient on these parameters is zero
- 'warn': consider the gradient zero, and print a warning. - ``'warn'``: consider the gradient zero, and print a warning
- 'raise': raise DisconnectedInputError. - ``'raise'``: raise `DisconnectedInputError`
add_names : bool add_names
If True, variables generated by grad will be named If ``True``, variables generated by `grad` will be named
(d<cost.name>/d<wrt.name>) provided that both cost and wrt ``(d<cost.name>/d<wrt.name>)`` provided that both `cost` and `wrt`
have names have names.
known_grads : OrderedDict, optional known_grads
A ordered dictionary mapping variables to their gradients. This is An ordered dictionary mapping variables to their gradients. This is
useful in the case where you know the gradient on some useful in the case where you know the gradients of some
variables but do not know the original cost. variables but do not know the original cost.
return_disconnected : {'zero', 'None', 'Disconnected'} return_disconnected
- 'zero' : If wrt[i] is disconnected, return value i will be - ``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
wrt[i].zeros_like() ``wrt[i].zeros_like()``
- 'None' : If wrt[i] is disconnected, return value i will be - ``'none'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
None ``None``
- 'Disconnected' : returns variables of type DisconnectedType - ``'disconnected'`` : returns variables of type `DisconnectedType`
null_gradients : {'raise', 'return'} null_gradients
Defines the behaviour if some of the variables in `wrt` have a Defines the behaviour when some of the variables in `wrt` have a
null gradient. The possibles values are: null gradient. The possibles values are:
- 'raise' : raise a NullTypeGradError exception - ``'raise'`` : raise a `NullTypeGradError` exception
- 'return' : return the null gradients - ``'return'`` : return the null gradients
Returns Returns
------- -------
variable or list/tuple of variables (matches `wrt`) A symbolic expression for the gradient of `cost` with respect to each
Symbolic expression of gradient of `cost` with respect to each of the `wrt` terms. If an element of `wrt` is not differentiable with
of the `wrt` terms. If an element of `wrt` is not respect to the output, then a zero variable is returned.
differentiable with respect to the output, then a zero
variable is returned.
""" """
t0 = time.time() t0 = time.time()
...@@ -498,30 +500,17 @@ def grad( ...@@ -498,30 +500,17 @@ def grad(
if cost is not None and isinstance(cost.type, NullType): if cost is not None and isinstance(cost.type, NullType):
raise ValueError( raise ValueError(
"Can't differentiate a NaN cost." "Can't differentiate a NaN cost. "
"cost is NaN because " + cost.type.why_null f"Cost is NaN because {cost.type.why_null}"
)
if cost is not None and cost.ndim != 0:
raise TypeError("cost must be a scalar.")
if isinstance(wrt, set):
raise TypeError(
"wrt must not be a set. sets have no defined "
"iteration order, so we can't return gradients in a"
" matching order."
) )
using_list = isinstance(wrt, list) if cost is not None and cost.type.ndim != 0:
using_tuple = isinstance(wrt, tuple) raise TypeError("Cost must be a scalar.")
if not using_list and not using_tuple:
wrt = [wrt]
for elem in wrt: if not isinstance(wrt, Sequence):
if not isinstance(elem, Variable): _wrt: List[Variable] = [wrt]
raise TypeError( else:
"Expected Variable, got " + str(elem) + " of type " + str(type(elem)) _wrt = [x for x in wrt]
)
outputs = [] outputs = []
if cost is not None: if cost is not None:
...@@ -529,16 +518,15 @@ def grad( ...@@ -529,16 +518,15 @@ def grad(
if known_grads is not None: if known_grads is not None:
outputs.extend(list(known_grads.keys())) outputs.extend(list(known_grads.keys()))
var_to_app_to_idx = _populate_var_to_app_to_idx(outputs, wrt, consider_constant) var_to_app_to_idx = _populate_var_to_app_to_idx(outputs, _wrt, consider_constant)
# build a dict mapping var to the gradient of cost with respect to var # build a dict mapping var to the gradient of cost with respect to var
grad_dict = OrderedDict() grad_dict = {}
if known_grads is None: if known_grads is None:
known_grads = OrderedDict() known_grads = {}
else:
m = "known_grads must be an OrderedDict. " assert isinstance(known_grads, dict)
assert isinstance(known_grads, OrderedDict) or len(known_grads) <= 1, m
# The gradient of the cost is 1 unless specified otherwise by known_grads. # The gradient of the cost is 1 unless specified otherwise by known_grads.
if cost is not None: if cost is not None:
...@@ -615,7 +603,7 @@ def grad( ...@@ -615,7 +603,7 @@ def grad(
# if wrt is such a variable, populate the grad_dict with this info # if wrt is such a variable, populate the grad_dict with this info
# so that wrt not being in var_to_app_to_idx won't cause an error below # so that wrt not being in var_to_app_to_idx won't cause an error below
# according to the flag, possibly raise an error if wrt is disconnected # according to the flag, possibly raise an error if wrt is disconnected
for elem in wrt: for elem in _wrt:
if elem not in var_to_app_to_idx and elem is not cost and elem not in grad_dict: if elem not in var_to_app_to_idx and elem is not cost and elem not in grad_dict:
handle_disconnected(elem) handle_disconnected(elem)
grad_dict[elem] = disconnected_type() grad_dict[elem] = disconnected_type()
...@@ -632,32 +620,38 @@ def grad( ...@@ -632,32 +620,38 @@ def grad(
if hasattr(g.type, "dtype"): if hasattr(g.type, "dtype"):
assert g.type.dtype in aesara.tensor.type.float_dtypes assert g.type.dtype in aesara.tensor.type.float_dtypes
rval = _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name) _rval: Sequence[Variable] = _populate_grad_dict(
var_to_app_to_idx, grad_dict, _wrt, cost_name
)
rval: MutableSequence[Optional[Variable]] = list(_rval)
for i in range(len(rval)): for i in range(len(_rval)):
if isinstance(rval[i].type, NullType): if isinstance(_rval[i].type, NullType):
if null_gradients == "raise": if null_gradients == "raise":
raise NullTypeGradError( raise NullTypeGradError(
f"grad encountered a NaN. {rval[i].type.why_null}" f"`grad` encountered a NaN. {_rval[i].type.why_null}"
) )
else: else:
assert null_gradients == "return" assert null_gradients == "return"
if isinstance(rval[i].type, DisconnectedType): if isinstance(_rval[i].type, DisconnectedType):
handle_disconnected(rval[i]) handle_disconnected(_rval[i])
if return_disconnected == "zero": if return_disconnected == "zero":
rval[i] = _float_zeros_like(wrt[i]) rval[i] = _float_zeros_like(_wrt[i])
elif return_disconnected == "None": elif return_disconnected.lower() == "none":
rval[i] = None rval[i] = None
else: else:
assert return_disconnected == "Disconnected" assert return_disconnected.lower() == "disconnected"
if using_tuple:
rval = tuple(rval)
elif not using_list:
(rval,) = rval
t1 = time.time() t1 = time.time()
global grad_time global grad_time
grad_time += t1 - t0 grad_time += t1 - t0
if isinstance(wrt, tuple):
return tuple(rval)
elif not isinstance(wrt, list):
return rval[0]
return rval return rval
...@@ -801,7 +795,7 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False): ...@@ -801,7 +795,7 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
for i in range(len(grads)): for i in range(len(grads)):
grads[i] += cost_grads[i] grads[i] += cost_grads[i]
pgrads = OrderedDict(zip(params, grads)) pgrads = dict(zip(params, grads))
# separate wrt from end grads: # separate wrt from end grads:
wrt_grads = list(pgrads[k] for k in wrt) wrt_grads = list(pgrads[k] for k in wrt)
end_grads = list(pgrads[k] for k in end) end_grads = list(pgrads[k] for k in end)
...@@ -916,7 +910,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant): ...@@ -916,7 +910,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
# var_to_app_to_idx[var][node] = [i,j] means node has # var_to_app_to_idx[var][node] = [i,j] means node has
# var as input at positions i and j # var as input at positions i and j
var_to_app_to_idx = OrderedDict() var_to_app_to_idx = dict()
# Set of variables that have been added to their true parents # Set of variables that have been added to their true parents
# ('true' here means that the elements of the variable are a function # ('true' here means that the elements of the variable are a function
...@@ -954,13 +948,13 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant): ...@@ -954,13 +948,13 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
continue continue
if ipt not in var_to_app_to_idx: if ipt not in var_to_app_to_idx:
# This object here *must* be an OrderedDict, because # This object here *must* be ordered, because
# we iterate over its keys when adding up the terms of the # we iterate over its keys when adding up the terms of the
# gradient on ipt. If it is a regular dict, the grad method # gradient on ipt. If it is a regular dict, the grad method
# will return something that is analytically correct, but # will return something that is analytically correct, but
# whose order of doing additions depends on the memory # whose order of doing additions depends on the memory
# location of the apply nodes. # location of the apply nodes.
var_to_app_to_idx[ipt] = OrderedDict() var_to_app_to_idx[ipt] = {}
app_to_idx = var_to_app_to_idx[ipt] app_to_idx = var_to_app_to_idx[ipt]
if app not in app_to_idx: if app not in app_to_idx:
app_to_idx[app] = [] app_to_idx[app] = []
...@@ -1052,7 +1046,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None): ...@@ -1052,7 +1046,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
""" """
# build a dict mapping node to the terms node contributes to each of # build a dict mapping node to the terms node contributes to each of
# its inputs' gradients # its inputs' gradients
term_dict = OrderedDict() term_dict = {}
def access_term_cache(node): def access_term_cache(node):
"""Populates term_dict[node] and returns it""" """Populates term_dict[node] and returns it"""
...@@ -1978,7 +1972,7 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise ...@@ -1978,7 +1972,7 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
if expression.ndim == 0: if expression.ndim == 0:
# expression is just a scalar, use grad # expression is just a scalar, use grad
return format_as( return as_list_or_tuple(
using_list, using_list,
using_tuple, using_tuple,
grad( grad(
...@@ -2013,7 +2007,7 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise ...@@ -2013,7 +2007,7 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
non_sequences=[expression] + wrt, non_sequences=[expression] + wrt,
) )
assert not updates, "Scan has returned a list of updates; this should not happen." assert not updates, "Scan has returned a list of updates; this should not happen."
return format_as(using_list, using_tuple, jacobs) return as_list_or_tuple(using_list, using_tuple, jacobs)
def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"): def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
...@@ -2093,7 +2087,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"): ...@@ -2093,7 +2087,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
not updates not updates
), "Scan has returned a list of updates; this should not happen." ), "Scan has returned a list of updates; this should not happen."
hessians.append(hess) hessians.append(hess)
return format_as(using_list, using_tuple, hessians) return as_list_or_tuple(using_list, using_tuple, hessians)
def _is_zero(x): def _is_zero(x):
...@@ -2134,7 +2128,6 @@ class ConsiderConstant(ViewOp): ...@@ -2134,7 +2128,6 @@ class ConsiderConstant(ViewOp):
consider_constant_ = ConsiderConstant() consider_constant_ = ConsiderConstant()
# I create a function only to have the doc show well.
def consider_constant(x): def consider_constant(x):
""" """
DEPRECATED: use zero_grad() or disconnected_grad() instead. DEPRECATED: use zero_grad() or disconnected_grad() instead.
......
...@@ -278,8 +278,6 @@ class TestGrad: ...@@ -278,8 +278,6 @@ class TestGrad:
g = grad(a1.outputs[0], a1.outputs[1], disconnected_inputs="ignore") g = grad(a1.outputs[0], a1.outputs[1], disconnected_inputs="ignore")
assert g.owner.op == at.fill assert g.owner.op == at.fill
assert g.owner.inputs[1].data == 0 assert g.owner.inputs[1].data == 0
with pytest.raises(TypeError):
grad(a1.outputs[0], "wtf")
def test_NNone_rval(self): def test_NNone_rval(self):
# grad: Test returning some zero value from grad # grad: Test returning some zero value from grad
......
Markdown is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment