Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ccfe2d3d
提交
ccfe2d3d
authored
7月 05, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
7月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor aesara.gradient and add type hints
上级
1d369b55
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
211 行增加
和
220 行删除
+211
-220
gradient.py
aesara/gradient.py
+211
-218
test_gradient.py
tests/test_gradient.py
+0
-2
没有找到文件。
aesara/gradient.py
浏览文件 @
ccfe2d3d
"""Driver for gradient calculations."""
"""Driver for gradient calculations."""
import
logging
import
time
import
time
import
warnings
import
warnings
from
collections
import
OrderedDict
from
functools
import
partial
,
reduce
from
functools
import
partial
,
reduce
from
typing
import
TYPE_CHECKING
,
Callable
,
List
,
Optional
,
Union
from
typing
import
(
TYPE_CHECKING
,
Callable
,
Dict
,
List
,
Mapping
,
MutableSequence
,
Optional
,
Sequence
,
Tuple
,
TypeVar
,
Union
,
)
import
numpy
as
np
import
numpy
as
np
from
typing_extensions
import
Literal
import
aesara
import
aesara
from
aesara.compile.ops
import
ViewOp
from
aesara.compile.ops
import
ViewOp
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph
import
utils
from
aesara.graph
import
utils
from
aesara.graph.basic
import
NominalVariable
,
Variable
from
aesara.graph.basic
import
Apply
,
NominalVariable
,
Variable
from
aesara.graph.null_type
import
NullType
,
null_type
from
aesara.graph.null_type
import
NullType
,
null_type
from
aesara.graph.op
import
get_test_values
from
aesara.graph.op
import
get_test_values
from
aesara.graph.type
import
Type
from
aesara.graph.type
import
Type
...
@@ -23,26 +34,18 @@ if TYPE_CHECKING:
...
@@ -23,26 +34,18 @@ if TYPE_CHECKING:
from
aesara.compile.mode
import
Mode
from
aesara.compile.mode
import
Mode
__docformat__
=
"restructuredtext en"
V
=
TypeVar
(
"V"
,
bound
=
Optional
[
Variable
])
_logger
=
logging
.
getLogger
(
"aesara.gradient"
)
# we can't do "import aesara.tensor"
# tensor depends on aesara.compile
# aesara.compile depends on aesara.gradient (this file)
# the reason aesara.compile depends on aesara.gradient
# is that aesara.compile.builders contains the op from graph
# functionality and it uses aesara.gradient to implement
# the new op's grad method
tensor
=
None
_msg_retType
=
"op.grad(...) returned a non-list"
# TODO: Refactor this so that it's not a global variable
grad_time
:
float
=
0.0
grad_time
=
0
# TODO: Add `overload` variants
def
format_as
(
use_list
,
use_tuple
,
outputs
):
def
as_list_or_tuple
(
"""
use_list
:
bool
,
use_tuple
:
bool
,
outputs
:
Union
[
V
,
Sequence
[
V
]]
Formats the outputs according to the flags `use_list` and `use_tuple`.
)
->
Union
[
V
,
List
[
V
],
Tuple
[
V
,
...
]]:
"""Return either a single object or a list/tuple of objects.
If `use_list` is True, `outputs` is returned as a list (if `outputs`
If `use_list` is True, `outputs` is returned as a list (if `outputs`
is not a list or a tuple then it is converted in a one element list).
is not a list or a tuple then it is converted in a one element list).
...
@@ -52,22 +55,25 @@ def format_as(use_list, use_tuple, outputs):
...
@@ -52,22 +55,25 @@ def format_as(use_list, use_tuple, outputs):
"""
"""
if
use_list
and
use_tuple
:
if
use_list
and
use_tuple
:
raise
ValueError
(
"Both flags cannot be simultaneously True"
)
raise
ValueError
(
"Both flags cannot be simultaneously True"
)
if
(
use_list
or
use_tuple
)
and
not
isinstance
(
outputs
,
(
list
,
tuple
)):
if
use_list
:
if
use_list
or
use_tuple
:
return
[
outputs
]
if
isinstance
(
outputs
,
Sequence
):
else
:
if
use_list
:
return
(
outputs
,)
return
list
(
outputs
)
elif
not
(
use_list
or
use_tuple
)
and
isinstance
(
outputs
,
(
list
,
tuple
)):
else
:
if
len
(
outputs
)
!=
1
:
return
tuple
(
outputs
)
raise
ValueError
(
"Wrong arguments; expected a one element list"
)
return
outputs
[
0
]
elif
use_list
or
use_tuple
:
if
use_list
:
return
list
(
outputs
)
else
:
else
:
return
tuple
(
outputs
)
if
use_list
:
return
[
outputs
]
else
:
return
(
outputs
,)
else
:
else
:
return
outputs
if
isinstance
(
outputs
,
Sequence
):
if
len
(
outputs
)
!=
1
:
raise
ValueError
(
"Wrong arguments; expected a one element list"
)
return
outputs
[
0
]
else
:
return
outputs
def
grad_not_implemented
(
op
,
x_pos
,
x
,
comment
=
""
):
def
grad_not_implemented
(
op
,
x_pos
,
x
,
comment
=
""
):
...
@@ -155,97 +161,87 @@ class DisconnectedType(Type):
...
@@ -155,97 +161,87 @@ class DisconnectedType(Type):
disconnected_type
=
DisconnectedType
()
disconnected_type
=
DisconnectedType
()
########################
def
Rop
(
# R Operator
f
:
Union
[
Variable
,
Sequence
[
Variable
]],
########################
wrt
:
Union
[
Variable
,
Sequence
[
Variable
]],
eval_points
:
Union
[
Variable
,
Sequence
[
Variable
]],
disconnected_outputs
:
Literal
[
"ignore"
,
"warn"
,
"raise"
]
=
"raise"
,
return_disconnected
:
Literal
[
"none"
,
"zero"
,
"disconnected"
]
=
"zero"
,
)
->
Union
[
Optional
[
Variable
],
Sequence
[
Optional
[
Variable
]]]:
"""Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`.
def
Rop
(
f
,
wrt
,
eval_points
,
disconnected_outputs
=
"raise"
,
return_disconnected
=
"zero"
):
Mathematically this stands for the Jacobian of `f` right multiplied by the
"""
`eval_points`.
Computes the R operation on `f` wrt to `wrt` at `eval_points`.
Mathematically this stands for the jacobian of `f` wrt
to `wrt` right muliplied by the eval points.
Parameters
Parameters
----------
----------
f : :class:`~aesara.graph.basic.Variable` or list of Variables
f
`f` stands for the output of the computational graph to which you
The outputs of the computational graph to which the R-operator is
want to apply the R operator
applied.
wrt : :class:`~aesara.graph.basic.Variable` or list of Variables
wrt
variables for which you compute the R operator of the expression
Variables for which the R-operator of `f` is computed.
described by `f`
eval_points
eval_points : :class:`~aesara.graph.basic.Variable` or list of Variables
Points at which to evaluate each of the variables in `wrt`.
evaluation points for each of the variables in `wrt`
disconnected_outputs
disconnected_outputs : str
Defines the behaviour if some of the variables in `f`
Defines the behaviour if some of the variables in `f`
have no dependency on any of the variable in `wrt` (or if
have no dependency on any of the variable in `wrt` (or if
all links are non-differentiable). The possible values are:
all links are non-differentiable). The possible values are:
-
'ignore'
: considers that the gradient on these parameters is zero.
-
``'ignore'``
: considers that the gradient on these parameters is zero.
-
'warn'
: consider the gradient zero, and print a warning.
-
``'warn'``
: consider the gradient zero, and print a warning.
-
'raise': raise DisconnectedInputError
.
-
``'raise'``: raise `DisconnectedInputError`
.
return_disconnected
: {'zero', 'None', 'Disconnected'}
return_disconnected
-
'zero' : If wrt[i] is disconnected, return value i
will be
-
``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i``
will be
wrt[i].zeros_like()
``wrt[i].zeros_like()``.
-
'None' : If wrt[i] is disconnected, return value i
will be
-
``'none'`` : If ``wrt[i]`` is disconnected, return value ``i``
will be
None
``None``
-
'Disconnected' : returns variables of type DisconnectedType
-
``'disconnected'`` : returns variables of type `DisconnectedType`
Returns
Returns
-------
-------
:class:`~aesara.graph.basic.Variable` or list/tuple of Variables depending on type of f
A symbolic expression such obeying
Symbolic expression such that
``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``,
R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]
where the indices in that expression are magic multidimensional
where the indices in that expression are magic multidimensional
indices that specify both the position within a list and all
indices that specify both the position within a list and all
coordinates of the tensor element
in the last
.
coordinates of the tensor element
s
.
If `wrt` is a list/tuple, then return a list/tuple with the results.
If `wrt` is a list/tuple, then return a list/tuple with the results.
"""
"""
using_list
=
isinstance
(
f
,
list
)
using_tuple
=
isinstance
(
f
,
tuple
)
if
not
isinstance
(
wrt
,
(
list
,
tuple
)):
if
not
isinstance
(
wrt
,
(
list
,
tuple
)):
wrt
=
[
wrt
]
_wrt
:
List
[
Variable
]
=
[
aesara
.
tensor
.
as_tensor_variable
(
wrt
)]
else
:
_wrt
=
[
aesara
.
tensor
.
as_tensor_variable
(
x
)
for
x
in
wrt
]
if
not
isinstance
(
eval_points
,
(
list
,
tuple
)):
if
not
isinstance
(
eval_points
,
(
list
,
tuple
)):
eval_points
=
[
eval_points
]
_eval_points
:
List
[
Variable
]
=
[
aesara
.
tensor
.
as_tensor_variable
(
eval_points
)]
else
:
_eval_points
=
[
aesara
.
tensor
.
as_tensor_variable
(
x
)
for
x
in
eval_points
]
if
not
isinstance
(
f
,
(
list
,
tuple
)):
if
not
isinstance
(
f
,
(
list
,
tuple
)):
f
=
[
f
]
_f
:
List
[
Variable
]
=
[
aesara
.
tensor
.
as_tensor_variable
(
f
)]
else
:
_f
=
[
aesara
.
tensor
.
as_tensor_variable
(
x
)
for
x
in
f
]
if
len
(
wrt
)
!=
len
(
eval_points
):
if
len
(
_wrt
)
!=
len
(
_
eval_points
):
raise
ValueError
(
"`wrt` must be the same length as `eval_points`."
)
raise
ValueError
(
"`wrt` must be the same length as `eval_points`."
)
# Check that each element of wrt corresponds to an element
# Check that each element of wrt corresponds to an element
# of eval_points with the same dimensionality.
# of eval_points with the same dimensionality.
for
pack
in
enumerate
(
zip
(
wrt
,
eval_points
)):
for
i
,
(
wrt_elem
,
eval_point
)
in
enumerate
(
zip
(
_wrt
,
_eval_points
)):
i
=
pack
[
0
]
wrt_elem
,
eval_point
=
pack
[
1
]
if
not
isinstance
(
wrt_elem
,
Variable
):
wrt_elem
=
aesara
.
tensor
.
as_tensor_variable
(
wrt_elem
)
if
not
isinstance
(
eval_point
,
Variable
):
eval_point
=
aesara
.
tensor
.
as_tensor_variable
(
eval_point
)
try
:
try
:
if
wrt_elem
.
type
.
ndim
!=
eval_point
.
type
.
ndim
:
if
wrt_elem
.
type
.
ndim
!=
eval_point
.
type
.
ndim
:
raise
ValueError
(
raise
ValueError
(
"Element "
f
"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: "
+
str
(
i
)
f
"{wrt_elem.type.ndim} and {eval_point.type.ndim}"
+
" of wrt/eval_point have mismatched "
+
"dimensionality: "
+
str
(
wrt_elem
.
type
.
ndim
)
+
" versus "
+
str
(
eval_point
.
type
.
ndim
)
)
)
except
AttributeError
:
except
AttributeError
:
# wrt_elem and eval_point don't always have ndim like random type
# wrt_elem and eval_point don't always have ndim like random type
# Tensor, Sparse have the ndim attribute
# Tensor, Sparse have the ndim attribute
pass
pass
seen_nodes
=
OrderedDict
()
seen_nodes
:
Dict
[
Apply
,
Sequence
[
Variable
]]
=
{}
def
_traverse
(
node
):
def
_traverse
(
node
):
"""TODO: writeme"""
"""TODO: writeme"""
...
@@ -260,8 +256,8 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="
...
@@ -260,8 +256,8 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="
# inputs of the node
# inputs of the node
local_eval_points
=
[]
local_eval_points
=
[]
for
inp
in
inputs
:
for
inp
in
inputs
:
if
inp
in
wrt
:
if
inp
in
_
wrt
:
local_eval_points
.
append
(
eval_points
[
wrt
.
index
(
inp
)])
local_eval_points
.
append
(
_eval_points
[
_
wrt
.
index
(
inp
)])
elif
inp
.
owner
is
None
:
elif
inp
.
owner
is
None
:
try
:
try
:
local_eval_points
.
append
(
inp
.
zeros_like
())
local_eval_points
.
append
(
inp
.
zeros_like
())
...
@@ -316,13 +312,13 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="
...
@@ -316,13 +312,13 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="
# end _traverse
# end _traverse
# Populate the dictionary
# Populate the dictionary
for
out
in
f
:
for
out
in
_
f
:
_traverse
(
out
.
owner
)
_traverse
(
out
.
owner
)
rval
=
[]
rval
:
List
[
Optional
[
Variable
]]
=
[]
for
out
in
f
:
for
out
in
_
f
:
if
out
in
wrt
:
if
out
in
_
wrt
:
rval
.
append
(
eval_points
[
wrt
.
index
(
out
)])
rval
.
append
(
_eval_points
[
_
wrt
.
index
(
out
)])
elif
(
elif
(
seen_nodes
.
get
(
out
.
owner
,
None
)
is
None
seen_nodes
.
get
(
out
.
owner
,
None
)
is
None
or
seen_nodes
[
out
.
owner
][
out
.
owner
.
outputs
.
index
(
out
)]
is
None
or
seen_nodes
[
out
.
owner
][
out
.
owner
.
outputs
.
index
(
out
)]
is
None
...
@@ -361,81 +357,89 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="
...
@@ -361,81 +357,89 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="
else
:
else
:
rval
.
append
(
seen_nodes
[
out
.
owner
][
out
.
owner
.
outputs
.
index
(
out
)])
rval
.
append
(
seen_nodes
[
out
.
owner
][
out
.
owner
.
outputs
.
index
(
out
)])
return
format_as
(
using_list
,
using_tuple
,
rval
)
using_list
=
isinstance
(
f
,
list
)
using_tuple
=
isinstance
(
f
,
tuple
)
return
as_list_or_tuple
(
using_list
,
using_tuple
,
rval
)
def
Lop
(
f
,
wrt
,
eval_points
,
consider_constant
=
None
,
disconnected_inputs
=
"raise"
):
def
Lop
(
"""Computes the L operation on `f` with respect to `wrt` at `eval_points`.
f
:
Union
[
Variable
,
Sequence
[
Variable
]],
wrt
:
Union
[
Variable
,
Sequence
[
Variable
]],
eval_points
:
Union
[
Variable
,
Sequence
[
Variable
]],
consider_constant
:
Optional
[
Sequence
[
Variable
]]
=
None
,
disconnected_inputs
:
Literal
[
"ignore"
,
"warn"
,
"raise"
]
=
"raise"
,
)
->
Union
[
Optional
[
Variable
],
Sequence
[
Optional
[
Variable
]]]:
"""Computes the L-operator applied to `f` with respect to `wrt` at `eval_points`.
Mathematically this stands for the Jacobian of `f` with respect to `wrt`
Mathematically this stands for the Jacobian of `f` with respect to `wrt`
left muliplied by the `eval_points`.
left muliplied by the `eval_points`.
Parameters
Parameters
----------
----------
f : :class:`~aesara.graph.basic.Variable` or list of Variables
f
`f` stands for the output of the computational graph to which you
The outputs of the computational graph to which the R-operator is
want to apply the L operator
applied.
wrt : :class:`~aesara.graph.basic.Variable` or list of Variables
wrt
variables for which you compute the L operator of the expression
Variables for which the R-operator of `f` is computed.
described by `f`
eval_points
eval_points : :class:`~aesara.graph.basic.Variable` or list of Variables
Points at which to evaluate each of the variables in `wrt`.
evaluation points for each of the variables in `f`
consider_constant
See `grad`.
disconnected_inputs
See `grad`.
Returns
Returns
-------
-------
:class:`~aesara.graph.basic.Variable` or list/tuple of Variables depending on type of `f`
A symbolic expression satisfying
Symbolic expression such that
``L_op[i] = sum_i (d f[i] / d wrt[j]) eval_point[i]``
``L_op[i] = sum_i (d f[i] / d wrt[j]) eval_point[i]``
where the indices in that expression are magic multidimensional
where the indices in that expression are magic multidimensional
indices that specify both the position within a list and all
indices that specify both the position within a list and all
coordinates of the tensor element
in the last
coordinates of the tensor element
s.
If `f` is a list/tuple, then return a list/tuple with the results.
If `f` is a list/tuple, then return a list/tuple with the results.
"""
"""
if
not
isinstance
(
eval_points
,
(
list
,
tuple
)):
if
not
isinstance
(
eval_points
,
(
list
,
tuple
)):
eval_points
=
[
eval_points
]
_eval_points
:
List
[
Variable
]
=
[
aesara
.
tensor
.
as_tensor_variable
(
eval_points
)]
else
:
using_list
=
isinstance
(
wrt
,
list
)
_eval_points
=
[
aesara
.
tensor
.
as_tensor_variable
(
x
)
for
x
in
eval_points
]
using_tuple
=
isinstance
(
wrt
,
tuple
)
if
not
isinstance
(
f
,
(
list
,
tuple
)):
if
not
isinstance
(
f
,
(
list
,
tuple
)):
f
=
[
f
]
_f
:
List
[
Variable
]
=
[
aesara
.
tensor
.
as_tensor_variable
(
f
)]
else
:
_f
=
[
aesara
.
tensor
.
as_tensor_variable
(
x
)
for
x
in
f
]
# make copies of f and grads so we don't modify the client's copy
grads
=
list
(
_eval_points
)
f
=
list
(
f
)
grads
=
list
(
eval_points
)
if
not
isinstance
(
wrt
,
(
list
,
tuple
)):
if
not
isinstance
(
wrt
,
(
list
,
tuple
)):
wrt
=
[
wrt
]
_wrt
:
List
[
Variable
]
=
[
aesara
.
tensor
.
as_tensor_variable
(
wrt
)]
else
:
_wrt
=
[
aesara
.
tensor
.
as_tensor_variable
(
x
)
for
x
in
wrt
]
assert
len
(
f
)
==
len
(
grads
)
assert
len
(
_
f
)
==
len
(
grads
)
known
=
OrderedDict
(
zip
(
f
,
grads
))
known
=
dict
(
zip
(
_
f
,
grads
))
ret
=
grad
(
ret
=
grad
(
cost
=
None
,
cost
=
None
,
known_grads
=
known
,
known_grads
=
known
,
consider_constant
=
consider_constant
,
consider_constant
=
consider_constant
,
wrt
=
wrt
,
wrt
=
_
wrt
,
disconnected_inputs
=
disconnected_inputs
,
disconnected_inputs
=
disconnected_inputs
,
)
)
return
format_as
(
using_list
,
using_tuple
,
ret
)
using_list
=
isinstance
(
wrt
,
list
)
using_tuple
=
isinstance
(
wrt
,
tuple
)
return
as_list_or_tuple
(
using_list
,
using_tuple
,
ret
)
#########################
# Gradient
#########################
def
grad
(
def
grad
(
cost
,
cost
:
Optional
[
Variable
]
,
wrt
,
wrt
:
Union
[
Variable
,
Sequence
[
Variable
]]
,
consider_constant
=
None
,
consider_constant
:
Optional
[
Sequence
[
Variable
]]
=
None
,
disconnected_inputs
=
"raise"
,
disconnected_inputs
:
Literal
[
"ignore"
,
"warn"
,
"raise"
]
=
"raise"
,
add_names
=
True
,
add_names
:
bool
=
True
,
known_grads
=
None
,
known_grads
:
Optional
[
Mapping
[
Variable
,
Variable
]]
=
None
,
return_disconnected
=
"zero"
,
return_disconnected
:
Literal
[
"none"
,
"zero"
,
"disconnected"
]
=
"zero"
,
null_gradients
=
"raise"
,
null_gradients
:
Literal
[
"raise"
,
"return"
]
=
"raise"
,
):
)
->
Union
[
Optional
[
Variable
],
Sequence
[
Optional
[
Variable
]]]
:
"""
"""
Return symbolic gradients of one cost with respect to one or more variables.
Return symbolic gradients of one cost with respect to one or more variables.
...
@@ -445,49 +449,47 @@ def grad(
...
@@ -445,49 +449,47 @@ def grad(
Parameters
Parameters
----------
----------
cost
: :class:`~aesara.graph.basic.Variable` scalar (0-dimensional) tensor variable or ``None``
cost
Value that we are differentiating (
that we want the gradient of).
Value that we are differentiating (
i.e. for which we want the
May be `None` if `known_grads` is provided.
gradient).
May be `None` if `known_grads` is provided.
wrt
: :class:`~aesara.graph.basic.Variable` or list of Variables
wrt
T
erm[s] with respect to which we want gradients
T
he term(s) with respect to which we want gradients.
consider_constant
: list of variables
consider_constant
Expressions not to backpropagate through
Expressions not to backpropagate through
.
disconnected_inputs : {'ignore', 'warn', 'raise'}
disconnected_inputs : {'ignore', 'warn', 'raise'}
Defines the behaviour if some of the variables in `wrt` are
Defines the behaviour if some of the variables in `wrt` are
not part of the computational graph computing `cost` (or if
not part of the computational graph computing `cost` (or if
all links are non-differentiable). The possible values are:
all links are non-differentiable). The possible values are:
-
'ignore': considers that the gradient on these parameters is zero.
-
``'ignore'``: considers that the gradient on these parameters is zero
-
'warn': consider the gradient zero, and print a warning.
-
``'warn'``: consider the gradient zero, and print a warning
-
'raise': raise DisconnectedInputError.
-
``'raise'``: raise `DisconnectedInputError`
add_names
: bool
add_names
If
True, variables generated by grad
will be named
If
``True``, variables generated by `grad`
will be named
(d<cost.name>/d<wrt.name>) provided that both cost and wrt
``(d<cost.name>/d<wrt.name>)`` provided that both `cost` and `wrt`
have names
have names
.
known_grads
: OrderedDict, optional
known_grads
A ordered dictionary mapping variables to their gradients. This is
A
n
ordered dictionary mapping variables to their gradients. This is
useful in the case where you know the gradient
on
some
useful in the case where you know the gradient
s of
some
variables but do not know the original cost.
variables but do not know the original cost.
return_disconnected
: {'zero', 'None', 'Disconnected'}
return_disconnected
-
'zero' : If wrt[i] is disconnected, return value i
will be
-
``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i``
will be
wrt[i].zeros_like()
``wrt[i].zeros_like()``
-
'None' : If wrt[i] is disconnected, return value i
will be
-
``'none'`` : If ``wrt[i]`` is disconnected, return value ``i``
will be
None
``None``
-
'Disconnected' : returns variables of type DisconnectedType
-
``'disconnected'`` : returns variables of type `DisconnectedType`
null_gradients
: {'raise', 'return'}
null_gradients
Defines the behaviour
if
some of the variables in `wrt` have a
Defines the behaviour
when
some of the variables in `wrt` have a
null gradient. The possibles values are:
null gradient. The possibles values are:
-
'raise' : raise a NullTypeGradError
exception
-
``'raise'`` : raise a `NullTypeGradError`
exception
-
'return'
: return the null gradients
-
``'return'``
: return the null gradients
Returns
Returns
-------
-------
variable or list/tuple of variables (matches `wrt`)
A symbolic expression for the gradient of `cost` with respect to each
Symbolic expression of gradient of `cost` with respect to each
of the `wrt` terms. If an element of `wrt` is not differentiable with
of the `wrt` terms. If an element of `wrt` is not
respect to the output, then a zero variable is returned.
differentiable with respect to the output, then a zero
variable is returned.
"""
"""
t0
=
time
.
time
()
t0
=
time
.
time
()
...
@@ -498,30 +500,17 @@ def grad(
...
@@ -498,30 +500,17 @@ def grad(
if
cost
is
not
None
and
isinstance
(
cost
.
type
,
NullType
):
if
cost
is
not
None
and
isinstance
(
cost
.
type
,
NullType
):
raise
ValueError
(
raise
ValueError
(
"Can't differentiate a NaN cost."
"Can't differentiate a NaN cost. "
"cost is NaN because "
+
cost
.
type
.
why_null
f
"Cost is NaN because {cost.type.why_null}"
)
if
cost
is
not
None
and
cost
.
ndim
!=
0
:
raise
TypeError
(
"cost must be a scalar."
)
if
isinstance
(
wrt
,
set
):
raise
TypeError
(
"wrt must not be a set. sets have no defined "
"iteration order, so we can't return gradients in a"
" matching order."
)
)
using_list
=
isinstance
(
wrt
,
list
)
if
cost
is
not
None
and
cost
.
type
.
ndim
!=
0
:
using_tuple
=
isinstance
(
wrt
,
tuple
)
raise
TypeError
(
"Cost must be a scalar."
)
if
not
using_list
and
not
using_tuple
:
wrt
=
[
wrt
]
for
elem
in
wrt
:
if
not
isinstance
(
wrt
,
Sequence
):
if
not
isinstance
(
elem
,
Variable
):
_wrt
:
List
[
Variable
]
=
[
wrt
]
raise
TypeError
(
else
:
"Expected Variable, got "
+
str
(
elem
)
+
" of type "
+
str
(
type
(
elem
))
_wrt
=
[
x
for
x
in
wrt
]
)
outputs
=
[]
outputs
=
[]
if
cost
is
not
None
:
if
cost
is
not
None
:
...
@@ -529,16 +518,15 @@ def grad(
...
@@ -529,16 +518,15 @@ def grad(
if
known_grads
is
not
None
:
if
known_grads
is
not
None
:
outputs
.
extend
(
list
(
known_grads
.
keys
()))
outputs
.
extend
(
list
(
known_grads
.
keys
()))
var_to_app_to_idx
=
_populate_var_to_app_to_idx
(
outputs
,
wrt
,
consider_constant
)
var_to_app_to_idx
=
_populate_var_to_app_to_idx
(
outputs
,
_
wrt
,
consider_constant
)
# build a dict mapping var to the gradient of cost with respect to var
# build a dict mapping var to the gradient of cost with respect to var
grad_dict
=
OrderedDict
()
grad_dict
=
{}
if
known_grads
is
None
:
if
known_grads
is
None
:
known_grads
=
OrderedDict
()
known_grads
=
{}
else
:
m
=
"known_grads must be an OrderedDict. "
assert
isinstance
(
known_grads
,
dict
)
assert
isinstance
(
known_grads
,
OrderedDict
)
or
len
(
known_grads
)
<=
1
,
m
# The gradient of the cost is 1 unless specified otherwise by known_grads.
# The gradient of the cost is 1 unless specified otherwise by known_grads.
if
cost
is
not
None
:
if
cost
is
not
None
:
...
@@ -615,7 +603,7 @@ def grad(
...
@@ -615,7 +603,7 @@ def grad(
# if wrt is such a variable, populate the grad_dict with this info
# if wrt is such a variable, populate the grad_dict with this info
# so that wrt not being in var_to_app_to_idx won't cause an error below
# so that wrt not being in var_to_app_to_idx won't cause an error below
# according to the flag, possibly raise an error if wrt is disconnected
# according to the flag, possibly raise an error if wrt is disconnected
for
elem
in
wrt
:
for
elem
in
_
wrt
:
if
elem
not
in
var_to_app_to_idx
and
elem
is
not
cost
and
elem
not
in
grad_dict
:
if
elem
not
in
var_to_app_to_idx
and
elem
is
not
cost
and
elem
not
in
grad_dict
:
handle_disconnected
(
elem
)
handle_disconnected
(
elem
)
grad_dict
[
elem
]
=
disconnected_type
()
grad_dict
[
elem
]
=
disconnected_type
()
...
@@ -632,32 +620,38 @@ def grad(
...
@@ -632,32 +620,38 @@ def grad(
if
hasattr
(
g
.
type
,
"dtype"
):
if
hasattr
(
g
.
type
,
"dtype"
):
assert
g
.
type
.
dtype
in
aesara
.
tensor
.
type
.
float_dtypes
assert
g
.
type
.
dtype
in
aesara
.
tensor
.
type
.
float_dtypes
rval
=
_populate_grad_dict
(
var_to_app_to_idx
,
grad_dict
,
wrt
,
cost_name
)
_rval
:
Sequence
[
Variable
]
=
_populate_grad_dict
(
var_to_app_to_idx
,
grad_dict
,
_wrt
,
cost_name
)
rval
:
MutableSequence
[
Optional
[
Variable
]]
=
list
(
_rval
)
for
i
in
range
(
len
(
rval
)):
for
i
in
range
(
len
(
_
rval
)):
if
isinstance
(
rval
[
i
]
.
type
,
NullType
):
if
isinstance
(
_
rval
[
i
]
.
type
,
NullType
):
if
null_gradients
==
"raise"
:
if
null_gradients
==
"raise"
:
raise
NullTypeGradError
(
raise
NullTypeGradError
(
f
"
grad encountered a NaN. {
rval[i].type.why_null}"
f
"
`grad` encountered a NaN. {_
rval[i].type.why_null}"
)
)
else
:
else
:
assert
null_gradients
==
"return"
assert
null_gradients
==
"return"
if
isinstance
(
rval
[
i
]
.
type
,
DisconnectedType
):
if
isinstance
(
_
rval
[
i
]
.
type
,
DisconnectedType
):
handle_disconnected
(
rval
[
i
])
handle_disconnected
(
_
rval
[
i
])
if
return_disconnected
==
"zero"
:
if
return_disconnected
==
"zero"
:
rval
[
i
]
=
_float_zeros_like
(
wrt
[
i
])
rval
[
i
]
=
_float_zeros_like
(
_
wrt
[
i
])
elif
return_disconnected
==
"N
one"
:
elif
return_disconnected
.
lower
()
==
"n
one"
:
rval
[
i
]
=
None
rval
[
i
]
=
None
else
:
else
:
assert
return_disconnected
==
"D
isconnected"
assert
return_disconnected
.
lower
()
==
"d
isconnected"
if
using_tuple
:
rval
=
tuple
(
rval
)
elif
not
using_list
:
(
rval
,)
=
rval
t1
=
time
.
time
()
t1
=
time
.
time
()
global
grad_time
global
grad_time
grad_time
+=
t1
-
t0
grad_time
+=
t1
-
t0
if
isinstance
(
wrt
,
tuple
):
return
tuple
(
rval
)
elif
not
isinstance
(
wrt
,
list
):
return
rval
[
0
]
return
rval
return
rval
...
@@ -801,7 +795,7 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
...
@@ -801,7 +795,7 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
for
i
in
range
(
len
(
grads
)):
for
i
in
range
(
len
(
grads
)):
grads
[
i
]
+=
cost_grads
[
i
]
grads
[
i
]
+=
cost_grads
[
i
]
pgrads
=
OrderedD
ict
(
zip
(
params
,
grads
))
pgrads
=
d
ict
(
zip
(
params
,
grads
))
# separate wrt from end grads:
# separate wrt from end grads:
wrt_grads
=
list
(
pgrads
[
k
]
for
k
in
wrt
)
wrt_grads
=
list
(
pgrads
[
k
]
for
k
in
wrt
)
end_grads
=
list
(
pgrads
[
k
]
for
k
in
end
)
end_grads
=
list
(
pgrads
[
k
]
for
k
in
end
)
...
@@ -916,7 +910,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
...
@@ -916,7 +910,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
# var_to_app_to_idx[var][node] = [i,j] means node has
# var_to_app_to_idx[var][node] = [i,j] means node has
# var as input at positions i and j
# var as input at positions i and j
var_to_app_to_idx
=
OrderedD
ict
()
var_to_app_to_idx
=
d
ict
()
# Set of variables that have been added to their true parents
# Set of variables that have been added to their true parents
# ('true' here means that the elements of the variable are a function
# ('true' here means that the elements of the variable are a function
...
@@ -954,13 +948,13 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
...
@@ -954,13 +948,13 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
continue
continue
if
ipt
not
in
var_to_app_to_idx
:
if
ipt
not
in
var_to_app_to_idx
:
# This object here *must* be
an OrderedDict
, because
# This object here *must* be
ordered
, because
# we iterate over its keys when adding up the terms of the
# we iterate over its keys when adding up the terms of the
# gradient on ipt. If it is a regular dict, the grad method
# gradient on ipt. If it is a regular dict, the grad method
# will return something that is analytically correct, but
# will return something that is analytically correct, but
# whose order of doing additions depends on the memory
# whose order of doing additions depends on the memory
# location of the apply nodes.
# location of the apply nodes.
var_to_app_to_idx
[
ipt
]
=
OrderedDict
()
var_to_app_to_idx
[
ipt
]
=
{}
app_to_idx
=
var_to_app_to_idx
[
ipt
]
app_to_idx
=
var_to_app_to_idx
[
ipt
]
if
app
not
in
app_to_idx
:
if
app
not
in
app_to_idx
:
app_to_idx
[
app
]
=
[]
app_to_idx
[
app
]
=
[]
...
@@ -1052,7 +1046,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
...
@@ -1052,7 +1046,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
"""
"""
# build a dict mapping node to the terms node contributes to each of
# build a dict mapping node to the terms node contributes to each of
# its inputs' gradients
# its inputs' gradients
term_dict
=
OrderedDict
()
term_dict
=
{}
def
access_term_cache
(
node
):
def
access_term_cache
(
node
):
"""Populates term_dict[node] and returns it"""
"""Populates term_dict[node] and returns it"""
...
@@ -1978,7 +1972,7 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
...
@@ -1978,7 +1972,7 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
if
expression
.
ndim
==
0
:
if
expression
.
ndim
==
0
:
# expression is just a scalar, use grad
# expression is just a scalar, use grad
return
format_as
(
return
as_list_or_tuple
(
using_list
,
using_list
,
using_tuple
,
using_tuple
,
grad
(
grad
(
...
@@ -2013,7 +2007,7 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
...
@@ -2013,7 +2007,7 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
non_sequences
=
[
expression
]
+
wrt
,
non_sequences
=
[
expression
]
+
wrt
,
)
)
assert
not
updates
,
"Scan has returned a list of updates; this should not happen."
assert
not
updates
,
"Scan has returned a list of updates; this should not happen."
return
format_as
(
using_list
,
using_tuple
,
jacobs
)
return
as_list_or_tuple
(
using_list
,
using_tuple
,
jacobs
)
def
hessian
(
cost
,
wrt
,
consider_constant
=
None
,
disconnected_inputs
=
"raise"
):
def
hessian
(
cost
,
wrt
,
consider_constant
=
None
,
disconnected_inputs
=
"raise"
):
...
@@ -2093,7 +2087,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
...
@@ -2093,7 +2087,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
not
updates
not
updates
),
"Scan has returned a list of updates; this should not happen."
),
"Scan has returned a list of updates; this should not happen."
hessians
.
append
(
hess
)
hessians
.
append
(
hess
)
return
format_as
(
using_list
,
using_tuple
,
hessians
)
return
as_list_or_tuple
(
using_list
,
using_tuple
,
hessians
)
def
_is_zero
(
x
):
def
_is_zero
(
x
):
...
@@ -2134,7 +2128,6 @@ class ConsiderConstant(ViewOp):
...
@@ -2134,7 +2128,6 @@ class ConsiderConstant(ViewOp):
consider_constant_
=
ConsiderConstant
()
consider_constant_
=
ConsiderConstant
()
# I create a function only to have the doc show well.
def
consider_constant
(
x
):
def
consider_constant
(
x
):
"""
"""
DEPRECATED: use zero_grad() or disconnected_grad() instead.
DEPRECATED: use zero_grad() or disconnected_grad() instead.
...
...
tests/test_gradient.py
浏览文件 @
ccfe2d3d
...
@@ -278,8 +278,6 @@ class TestGrad:
...
@@ -278,8 +278,6 @@ class TestGrad:
g
=
grad
(
a1
.
outputs
[
0
],
a1
.
outputs
[
1
],
disconnected_inputs
=
"ignore"
)
g
=
grad
(
a1
.
outputs
[
0
],
a1
.
outputs
[
1
],
disconnected_inputs
=
"ignore"
)
assert
g
.
owner
.
op
==
at
.
fill
assert
g
.
owner
.
op
==
at
.
fill
assert
g
.
owner
.
inputs
[
1
]
.
data
==
0
assert
g
.
owner
.
inputs
[
1
]
.
data
==
0
with
pytest
.
raises
(
TypeError
):
grad
(
a1
.
outputs
[
0
],
"wtf"
)
def
test_NNone_rval
(
self
):
def
test_NNone_rval
(
self
):
# grad: Test returning some zero value from grad
# grad: Test returning some zero value from grad
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论