Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ccfe2d3d
提交
ccfe2d3d
authored
7月 05, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
7月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor aesara.gradient and add type hints
上级
1d369b55
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
211 行增加
和
220 行删除
+211
-220
gradient.py
aesara/gradient.py
+211
-218
test_gradient.py
tests/test_gradient.py
+0
-2
没有找到文件。
aesara/gradient.py
浏览文件 @
ccfe2d3d
"""Driver for gradient calculations."""
import
logging
import
time
import
warnings
from
collections
import
OrderedDict
from
functools
import
partial
,
reduce
from
typing
import
TYPE_CHECKING
,
Callable
,
List
,
Optional
,
Union
from
typing
import
(
TYPE_CHECKING
,
Callable
,
Dict
,
List
,
Mapping
,
MutableSequence
,
Optional
,
Sequence
,
Tuple
,
TypeVar
,
Union
,
)
import
numpy
as
np
from
typing_extensions
import
Literal
import
aesara
from
aesara.compile.ops
import
ViewOp
from
aesara.configdefaults
import
config
from
aesara.graph
import
utils
from
aesara.graph.basic
import
NominalVariable
,
Variable
from
aesara.graph.basic
import
Apply
,
NominalVariable
,
Variable
from
aesara.graph.null_type
import
NullType
,
null_type
from
aesara.graph.op
import
get_test_values
from
aesara.graph.type
import
Type
...
...
@@ -23,26 +34,18 @@ if TYPE_CHECKING:
from
aesara.compile.mode
import
Mode
__docformat__
=
"restructuredtext en"
_logger
=
logging
.
getLogger
(
"aesara.gradient"
)
V
=
TypeVar
(
"V"
,
bound
=
Optional
[
Variable
])
# we can't do "import aesara.tensor"
# tensor depends on aesara.compile
# aesara.compile depends on aesara.gradient (this file)
# the reason aesara.compile depends on aesara.gradient
# is that aesara.compile.builders contains the op from graph
# functionality and it uses aesara.gradient to implement
# the new op's grad method
tensor
=
None
_msg_retType
=
"op.grad(...) returned a non-list"
# TODO: Refactor this so that it's not a global variable
grad_time
:
float
=
0.0
grad_time
=
0
def
format_as
(
use_list
,
use_tuple
,
outputs
):
"""
Formats the outputs according to the flags `use_list` and `use_tuple`.
# TODO: Add `overload` variants
def
as_list_or_tuple
(
use_list
:
bool
,
use_tuple
:
bool
,
outputs
:
Union
[
V
,
Sequence
[
V
]]
)
->
Union
[
V
,
List
[
V
],
Tuple
[
V
,
...
]]:
"""Return either a single object or a list/tuple of objects.
If `use_list` is True, `outputs` is returned as a list (if `outputs`
is not a list or a tuple then it is converted in a one element list).
...
...
@@ -52,22 +55,25 @@ def format_as(use_list, use_tuple, outputs):
"""
if
use_list
and
use_tuple
:
raise
ValueError
(
"Both flags cannot be simultaneously True"
)
if
(
use_list
or
use_tuple
)
and
not
isinstance
(
outputs
,
(
list
,
tuple
)):
if
use_list
:
return
[
outputs
]
else
:
return
(
outputs
,)
elif
not
(
use_list
or
use_tuple
)
and
isinstance
(
outputs
,
(
list
,
tuple
)):
if
len
(
outputs
)
!=
1
:
raise
ValueError
(
"Wrong arguments; expected a one element list"
)
return
outputs
[
0
]
elif
use_list
or
use_tuple
:
if
use_list
:
return
list
(
outputs
)
if
use_list
or
use_tuple
:
if
isinstance
(
outputs
,
Sequence
):
if
use_list
:
return
list
(
outputs
)
else
:
return
tuple
(
outputs
)
else
:
return
tuple
(
outputs
)
if
use_list
:
return
[
outputs
]
else
:
return
(
outputs
,)
else
:
return
outputs
if
isinstance
(
outputs
,
Sequence
):
if
len
(
outputs
)
!=
1
:
raise
ValueError
(
"Wrong arguments; expected a one element list"
)
return
outputs
[
0
]
else
:
return
outputs
def
grad_not_implemented
(
op
,
x_pos
,
x
,
comment
=
""
):
...
...
@@ -155,97 +161,87 @@ class DisconnectedType(Type):
disconnected_type
=
DisconnectedType
()
########################
# R Operator
########################
def
Rop
(
f
:
Union
[
Variable
,
Sequence
[
Variable
]],
wrt
:
Union
[
Variable
,
Sequence
[
Variable
]],
eval_points
:
Union
[
Variable
,
Sequence
[
Variable
]],
disconnected_outputs
:
Literal
[
"ignore"
,
"warn"
,
"raise"
]
=
"raise"
,
return_disconnected
:
Literal
[
"none"
,
"zero"
,
"disconnected"
]
=
"zero"
,
)
->
Union
[
Optional
[
Variable
],
Sequence
[
Optional
[
Variable
]]]:
"""Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`.
def
Rop
(
f
,
wrt
,
eval_points
,
disconnected_outputs
=
"raise"
,
return_disconnected
=
"zero"
):
"""
Computes the R operation on `f` wrt to `wrt` at `eval_points`.
Mathematically this stands for the jacobian of `f` wrt
to `wrt` right multiplied by the eval points.
Mathematically this stands for the Jacobian of `f` right multiplied by the
`eval_points`.
Parameters
----------
f : :class:`~aesara.graph.basic.Variable` or list of Variables
`f` stands for the output of the computational graph to which you
want to apply the R operator
wrt : :class:`~aesara.graph.basic.Variable` or list of Variables
variables for which you compute the R operator of the expression
described by `f`
eval_points : :class:`~aesara.graph.basic.Variable` or list of Variables
evaluation points for each of the variables in `wrt`
disconnected_outputs : str
f
The outputs of the computational graph to which the R-operator is
applied.
wrt
Variables for which the R-operator of `f` is computed.
eval_points
Points at which to evaluate each of the variables in `wrt`.
disconnected_outputs
Defines the behaviour if some of the variables in `f`
have no dependency on any of the variable in `wrt` (or if
all links are non-differentiable). The possible values are:
-
'ignore'
: considers that the gradient on these parameters is zero.
-
'warn'
: consider the gradient zero, and print a warning.
-
'raise': raise DisconnectedInputError
.
-
``'ignore'``
: considers that the gradient on these parameters is zero.
-
``'warn'``
: consider the gradient zero, and print a warning.
-
``'raise'``: raise `DisconnectedInputError`
.
return_disconnected
: {'zero', 'None', 'Disconnected'}
-
'zero' : If wrt[i] is disconnected, return value i
will be
wrt[i].zeros_like()
-
'None' : If wrt[i] is disconnected, return value i
will be
None
-
'Disconnected' : returns variables of type DisconnectedType
return_disconnected
-
``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i``
will be
``wrt[i].zeros_like()``.
-
``'none'`` : If ``wrt[i]`` is disconnected, return value ``i``
will be
``None``
-
``'disconnected'`` : returns variables of type `DisconnectedType`
Returns
-------
:class:`~aesara.graph.basic.Variable` or list/tuple of Variables depending on type of f
Symbolic expression such that
R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]
A symbolic expression obeying
``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``,
where the indices in that expression are magic multidimensional
indices that specify both the position within a list and all
coordinates of the tensor element
in the last
.
coordinates of the tensor element
s
.
If `wrt` is a list/tuple, then return a list/tuple with the results.
"""
using_list
=
isinstance
(
f
,
list
)
using_tuple
=
isinstance
(
f
,
tuple
)
if
not
isinstance
(
wrt
,
(
list
,
tuple
)):
wrt
=
[
wrt
]
_wrt
:
List
[
Variable
]
=
[
aesara
.
tensor
.
as_tensor_variable
(
wrt
)]
else
:
_wrt
=
[
aesara
.
tensor
.
as_tensor_variable
(
x
)
for
x
in
wrt
]
if
not
isinstance
(
eval_points
,
(
list
,
tuple
)):
eval_points
=
[
eval_points
]
_eval_points
:
List
[
Variable
]
=
[
aesara
.
tensor
.
as_tensor_variable
(
eval_points
)]
else
:
_eval_points
=
[
aesara
.
tensor
.
as_tensor_variable
(
x
)
for
x
in
eval_points
]
if
not
isinstance
(
f
,
(
list
,
tuple
)):
f
=
[
f
]
_f
:
List
[
Variable
]
=
[
aesara
.
tensor
.
as_tensor_variable
(
f
)]
else
:
_f
=
[
aesara
.
tensor
.
as_tensor_variable
(
x
)
for
x
in
f
]
if
len
(
wrt
)
!=
len
(
eval_points
):
if
len
(
_wrt
)
!=
len
(
_
eval_points
):
raise
ValueError
(
"`wrt` must be the same length as `eval_points`."
)
# Check that each element of wrt corresponds to an element
# of eval_points with the same dimensionality.
for
pack
in
enumerate
(
zip
(
wrt
,
eval_points
)):
i
=
pack
[
0
]
wrt_elem
,
eval_point
=
pack
[
1
]
if
not
isinstance
(
wrt_elem
,
Variable
):
wrt_elem
=
aesara
.
tensor
.
as_tensor_variable
(
wrt_elem
)
if
not
isinstance
(
eval_point
,
Variable
):
eval_point
=
aesara
.
tensor
.
as_tensor_variable
(
eval_point
)
for
i
,
(
wrt_elem
,
eval_point
)
in
enumerate
(
zip
(
_wrt
,
_eval_points
)):
try
:
if
wrt_elem
.
type
.
ndim
!=
eval_point
.
type
.
ndim
:
raise
ValueError
(
"Element "
+
str
(
i
)
+
" of wrt/eval_point have mismatched "
+
"dimensionality: "
+
str
(
wrt_elem
.
type
.
ndim
)
+
" versus "
+
str
(
eval_point
.
type
.
ndim
)
f
"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: "
f
"{wrt_elem.type.ndim} and {eval_point.type.ndim}"
)
except
AttributeError
:
# wrt_elem and eval_point don't always have ndim like random type
# Tensor, Sparse have the ndim attribute
pass
seen_nodes
=
OrderedDict
()
seen_nodes
:
Dict
[
Apply
,
Sequence
[
Variable
]]
=
{}
def
_traverse
(
node
):
"""TODO: writeme"""
...
...
@@ -260,8 +256,8 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="
# inputs of the node
local_eval_points
=
[]
for
inp
in
inputs
:
if
inp
in
wrt
:
local_eval_points
.
append
(
eval_points
[
wrt
.
index
(
inp
)])
if
inp
in
_
wrt
:
local_eval_points
.
append
(
_eval_points
[
_
wrt
.
index
(
inp
)])
elif
inp
.
owner
is
None
:
try
:
local_eval_points
.
append
(
inp
.
zeros_like
())
...
...
@@ -316,13 +312,13 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="
# end _traverse
# Populate the dictionary
for
out
in
f
:
for
out
in
_
f
:
_traverse
(
out
.
owner
)
rval
=
[]
for
out
in
f
:
if
out
in
wrt
:
rval
.
append
(
eval_points
[
wrt
.
index
(
out
)])
rval
:
List
[
Optional
[
Variable
]]
=
[]
for
out
in
_
f
:
if
out
in
_
wrt
:
rval
.
append
(
_eval_points
[
_
wrt
.
index
(
out
)])
elif
(
seen_nodes
.
get
(
out
.
owner
,
None
)
is
None
or
seen_nodes
[
out
.
owner
][
out
.
owner
.
outputs
.
index
(
out
)]
is
None
...
...
@@ -361,81 +357,89 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="
else
:
rval
.
append
(
seen_nodes
[
out
.
owner
][
out
.
owner
.
outputs
.
index
(
out
)])
return
format_as
(
using_list
,
using_tuple
,
rval
)
using_list
=
isinstance
(
f
,
list
)
using_tuple
=
isinstance
(
f
,
tuple
)
return
as_list_or_tuple
(
using_list
,
using_tuple
,
rval
)
def
Lop
(
f
,
wrt
,
eval_points
,
consider_constant
=
None
,
disconnected_inputs
=
"raise"
):
"""Computes the L operation on `f` with respect to `wrt` at `eval_points`.
def
Lop
(
f
:
Union
[
Variable
,
Sequence
[
Variable
]],
wrt
:
Union
[
Variable
,
Sequence
[
Variable
]],
eval_points
:
Union
[
Variable
,
Sequence
[
Variable
]],
consider_constant
:
Optional
[
Sequence
[
Variable
]]
=
None
,
disconnected_inputs
:
Literal
[
"ignore"
,
"warn"
,
"raise"
]
=
"raise"
,
)
->
Union
[
Optional
[
Variable
],
Sequence
[
Optional
[
Variable
]]]:
"""Computes the L-operator applied to `f` with respect to `wrt` at `eval_points`.
Mathematically this stands for the Jacobian of `f` with respect to `wrt`
left multiplied by the `eval_points`.
Parameters
----------
f : :class:`~aesara.graph.basic.Variable` or list of Variables
`f` stands for the output of the computational graph to which you
want to apply the L operator
wrt : :class:`~aesara.graph.basic.Variable` or list of Variables
variables for which you compute the L operator of the expression
described by `f`
eval_points : :class:`~aesara.graph.basic.Variable` or list of Variables
evaluation points for each of the variables in `f`
f
The outputs of the computational graph to which the L-operator is
applied.
wrt
Variables for which the L-operator of `f` is computed.
eval_points
Points at which to evaluate each of the variables in `f`.
consider_constant
See `grad`.
disconnected_inputs
See `grad`.
Returns
-------
:class:`~aesara.graph.basic.Variable` or list/tuple of Variables depending on type of `f`
Symbolic expression such that
A symbolic expression satisfying
``L_op[j] = sum_i (d f[i] / d wrt[j]) eval_point[i]``
where the indices in that expression are magic multidimensional
indices that specify both the position within a list and all
coordinates of the tensor element
in the last
coordinates of the tensor element
s.
If `f` is a list/tuple, then return a list/tuple with the results.
"""
if
not
isinstance
(
eval_points
,
(
list
,
tuple
)):
eval_points
=
[
eval_points
]
using_list
=
isinstance
(
wrt
,
list
)
using_tuple
=
isinstance
(
wrt
,
tuple
)
_eval_points
:
List
[
Variable
]
=
[
aesara
.
tensor
.
as_tensor_variable
(
eval_points
)]
else
:
_eval_points
=
[
aesara
.
tensor
.
as_tensor_variable
(
x
)
for
x
in
eval_points
]
if
not
isinstance
(
f
,
(
list
,
tuple
)):
f
=
[
f
]
_f
:
List
[
Variable
]
=
[
aesara
.
tensor
.
as_tensor_variable
(
f
)]
else
:
_f
=
[
aesara
.
tensor
.
as_tensor_variable
(
x
)
for
x
in
f
]
# make copies of f and grads so we don't modify the client's copy
f
=
list
(
f
)
grads
=
list
(
eval_points
)
grads
=
list
(
_eval_points
)
if
not
isinstance
(
wrt
,
(
list
,
tuple
)):
wrt
=
[
wrt
]
_wrt
:
List
[
Variable
]
=
[
aesara
.
tensor
.
as_tensor_variable
(
wrt
)]
else
:
_wrt
=
[
aesara
.
tensor
.
as_tensor_variable
(
x
)
for
x
in
wrt
]
assert
len
(
f
)
==
len
(
grads
)
known
=
OrderedDict
(
zip
(
f
,
grads
))
assert
len
(
_
f
)
==
len
(
grads
)
known
=
dict
(
zip
(
_
f
,
grads
))
ret
=
grad
(
cost
=
None
,
known_grads
=
known
,
consider_constant
=
consider_constant
,
wrt
=
wrt
,
wrt
=
_
wrt
,
disconnected_inputs
=
disconnected_inputs
,
)
return
format_as
(
using_list
,
using_tuple
,
ret
)
#########################
# Gradient
#########################
using_list
=
isinstance
(
wrt
,
list
)
using_tuple
=
isinstance
(
wrt
,
tuple
)
return
as_list_or_tuple
(
using_list
,
using_tuple
,
ret
)
def
grad
(
cost
,
wrt
,
consider_constant
=
None
,
disconnected_inputs
=
"raise"
,
add_names
=
True
,
known_grads
=
None
,
return_disconnected
=
"zero"
,
null_gradients
=
"raise"
,
):
cost
:
Optional
[
Variable
]
,
wrt
:
Union
[
Variable
,
Sequence
[
Variable
]]
,
consider_constant
:
Optional
[
Sequence
[
Variable
]]
=
None
,
disconnected_inputs
:
Literal
[
"ignore"
,
"warn"
,
"raise"
]
=
"raise"
,
add_names
:
bool
=
True
,
known_grads
:
Optional
[
Mapping
[
Variable
,
Variable
]]
=
None
,
return_disconnected
:
Literal
[
"none"
,
"zero"
,
"disconnected"
]
=
"zero"
,
null_gradients
:
Literal
[
"raise"
,
"return"
]
=
"raise"
,
)
->
Union
[
Optional
[
Variable
],
Sequence
[
Optional
[
Variable
]]]
:
"""
Return symbolic gradients of one cost with respect to one or more variables.
...
...
@@ -445,49 +449,47 @@ def grad(
Parameters
----------
cost
: :class:`~aesara.graph.basic.Variable` scalar (0-dimensional) tensor variable or ``None``
Value that we are differentiating (
that we want the gradient of).
May be `None` if `known_grads` is provided.
wrt
: :class:`~aesara.graph.basic.Variable` or list of Variables
T
erm[s] with respect to which we want gradients
consider_constant
: list of variables
Expressions not to backpropagate through
cost
Value that we are differentiating (
i.e. for which we want the
gradient).
May be `None` if `known_grads` is provided.
wrt
T
he term(s) with respect to which we want gradients.
consider_constant
Expressions not to backpropagate through
.
disconnected_inputs : {'ignore', 'warn', 'raise'}
Defines the behaviour if some of the variables in `wrt` are
not part of the computational graph computing `cost` (or if
all links are non-differentiable). The possible values are:
-
'ignore': considers that the gradient on these parameters is zero.
-
'warn': consider the gradient zero, and print a warning.
-
'raise': raise DisconnectedInputError.
add_names
: bool
If
True, variables generated by grad
will be named
(d<cost.name>/d<wrt.name>) provided that both cost and wrt
have names
known_grads
: OrderedDict, optional
A ordered dictionary mapping variables to their gradients. This is
useful in the case where you know the gradient
on
some
-
``'ignore'``: considers that the gradient on these parameters is zero
-
``'warn'``: consider the gradient zero, and print a warning
-
``'raise'``: raise `DisconnectedInputError`
add_names
If
``True``, variables generated by `grad`
will be named
``(d<cost.name>/d<wrt.name>)`` provided that both `cost` and `wrt`
have names
.
known_grads
A
n
ordered dictionary mapping variables to their gradients. This is
useful in the case where you know the gradient
s of
some
variables but do not know the original cost.
return_disconnected
: {'zero', 'None', 'Disconnected'}
-
'zero' : If wrt[i] is disconnected, return value i
will be
wrt[i].zeros_like()
-
'None' : If wrt[i] is disconnected, return value i
will be
None
-
'Disconnected' : returns variables of type DisconnectedType
null_gradients
: {'raise', 'return'}
Defines the behaviour
if
some of the variables in `wrt` have a
return_disconnected
-
``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i``
will be
``wrt[i].zeros_like()``
-
``'none'`` : If ``wrt[i]`` is disconnected, return value ``i``
will be
``None``
-
``'disconnected'`` : returns variables of type `DisconnectedType`
null_gradients
Defines the behaviour
when
some of the variables in `wrt` have a
null gradient. The possible values are:
-
'raise' : raise a NullTypeGradError
exception
-
'return'
: return the null gradients
-
``'raise'`` : raise a `NullTypeGradError`
exception
-
``'return'``
: return the null gradients
Returns
-------
variable or list/tuple of variables (matches `wrt`)
Symbolic expression of gradient of `cost` with respect to each
of the `wrt` terms. If an element of `wrt` is not
differentiable with respect to the output, then a zero
variable is returned.
A symbolic expression for the gradient of `cost` with respect to each
of the `wrt` terms. If an element of `wrt` is not differentiable with
respect to the output, then a zero variable is returned.
"""
t0
=
time
.
time
()
...
...
@@ -498,30 +500,17 @@ def grad(
if
cost
is
not
None
and
isinstance
(
cost
.
type
,
NullType
):
raise
ValueError
(
"Can't differentiate a NaN cost."
"cost is NaN because "
+
cost
.
type
.
why_null
)
if
cost
is
not
None
and
cost
.
ndim
!=
0
:
raise
TypeError
(
"cost must be a scalar."
)
if
isinstance
(
wrt
,
set
):
raise
TypeError
(
"wrt must not be a set. sets have no defined "
"iteration order, so we can't return gradients in a"
" matching order."
"Can't differentiate a NaN cost. "
f
"Cost is NaN because {cost.type.why_null}"
)
using_list
=
isinstance
(
wrt
,
list
)
using_tuple
=
isinstance
(
wrt
,
tuple
)
if
not
using_list
and
not
using_tuple
:
wrt
=
[
wrt
]
if
cost
is
not
None
and
cost
.
type
.
ndim
!=
0
:
raise
TypeError
(
"Cost must be a scalar."
)
for
elem
in
wrt
:
if
not
isinstance
(
elem
,
Variable
):
raise
TypeError
(
"Expected Variable, got "
+
str
(
elem
)
+
" of type "
+
str
(
type
(
elem
))
)
if
not
isinstance
(
wrt
,
Sequence
):
_wrt
:
List
[
Variable
]
=
[
wrt
]
else
:
_wrt
=
[
x
for
x
in
wrt
]
outputs
=
[]
if
cost
is
not
None
:
...
...
@@ -529,16 +518,15 @@ def grad(
if
known_grads
is
not
None
:
outputs
.
extend
(
list
(
known_grads
.
keys
()))
var_to_app_to_idx
=
_populate_var_to_app_to_idx
(
outputs
,
wrt
,
consider_constant
)
var_to_app_to_idx
=
_populate_var_to_app_to_idx
(
outputs
,
_
wrt
,
consider_constant
)
# build a dict mapping var to the gradient of cost with respect to var
grad_dict
=
OrderedDict
()
grad_dict
=
{}
if
known_grads
is
None
:
known_grads
=
OrderedDict
()
else
:
m
=
"known_grads must be an OrderedDict. "
assert
isinstance
(
known_grads
,
OrderedDict
)
or
len
(
known_grads
)
<=
1
,
m
known_grads
=
{}
assert
isinstance
(
known_grads
,
dict
)
# The gradient of the cost is 1 unless specified otherwise by known_grads.
if
cost
is
not
None
:
...
...
@@ -615,7 +603,7 @@ def grad(
# if wrt is such a variable, populate the grad_dict with this info
# so that wrt not being in var_to_app_to_idx won't cause an error below
# according to the flag, possibly raise an error if wrt is disconnected
for
elem
in
wrt
:
for
elem
in
_
wrt
:
if
elem
not
in
var_to_app_to_idx
and
elem
is
not
cost
and
elem
not
in
grad_dict
:
handle_disconnected
(
elem
)
grad_dict
[
elem
]
=
disconnected_type
()
...
...
@@ -632,32 +620,38 @@ def grad(
if
hasattr
(
g
.
type
,
"dtype"
):
assert
g
.
type
.
dtype
in
aesara
.
tensor
.
type
.
float_dtypes
rval
=
_populate_grad_dict
(
var_to_app_to_idx
,
grad_dict
,
wrt
,
cost_name
)
_rval
:
Sequence
[
Variable
]
=
_populate_grad_dict
(
var_to_app_to_idx
,
grad_dict
,
_wrt
,
cost_name
)
rval
:
MutableSequence
[
Optional
[
Variable
]]
=
list
(
_rval
)
for
i
in
range
(
len
(
rval
)):
if
isinstance
(
rval
[
i
]
.
type
,
NullType
):
for
i
in
range
(
len
(
_
rval
)):
if
isinstance
(
_
rval
[
i
]
.
type
,
NullType
):
if
null_gradients
==
"raise"
:
raise
NullTypeGradError
(
f
"
grad encountered a NaN. {
rval[i].type.why_null}"
f
"
`grad` encountered a NaN. {_
rval[i].type.why_null}"
)
else
:
assert
null_gradients
==
"return"
if
isinstance
(
rval
[
i
]
.
type
,
DisconnectedType
):
handle_disconnected
(
rval
[
i
])
if
isinstance
(
_
rval
[
i
]
.
type
,
DisconnectedType
):
handle_disconnected
(
_
rval
[
i
])
if
return_disconnected
==
"zero"
:
rval
[
i
]
=
_float_zeros_like
(
wrt
[
i
])
elif
return_disconnected
==
"N
one"
:
rval
[
i
]
=
_float_zeros_like
(
_
wrt
[
i
])
elif
return_disconnected
.
lower
()
==
"n
one"
:
rval
[
i
]
=
None
else
:
assert
return_disconnected
==
"D
isconnected"
assert
return_disconnected
.
lower
()
==
"d
isconnected"
if
using_tuple
:
rval
=
tuple
(
rval
)
elif
not
using_list
:
(
rval
,)
=
rval
t1
=
time
.
time
()
global
grad_time
grad_time
+=
t1
-
t0
if
isinstance
(
wrt
,
tuple
):
return
tuple
(
rval
)
elif
not
isinstance
(
wrt
,
list
):
return
rval
[
0
]
return
rval
...
...
@@ -801,7 +795,7 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
for
i
in
range
(
len
(
grads
)):
grads
[
i
]
+=
cost_grads
[
i
]
pgrads
=
OrderedD
ict
(
zip
(
params
,
grads
))
pgrads
=
d
ict
(
zip
(
params
,
grads
))
# separate wrt from end grads:
wrt_grads
=
list
(
pgrads
[
k
]
for
k
in
wrt
)
end_grads
=
list
(
pgrads
[
k
]
for
k
in
end
)
...
...
@@ -916,7 +910,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
# var_to_app_to_idx[var][node] = [i,j] means node has
# var as input at positions i and j
var_to_app_to_idx
=
OrderedD
ict
()
var_to_app_to_idx
=
d
ict
()
# Set of variables that have been added to their true parents
# ('true' here means that the elements of the variable are a function
...
...
@@ -954,13 +948,13 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
continue
if
ipt
not
in
var_to_app_to_idx
:
# This object here *must* be
an OrderedDict
, because
# This object here *must* be
ordered
, because
# we iterate over its keys when adding up the terms of the
# gradient on ipt. If it is a regular dict, the grad method
# will return something that is analytically correct, but
# whose order of doing additions depends on the memory
# location of the apply nodes.
var_to_app_to_idx
[
ipt
]
=
OrderedDict
()
var_to_app_to_idx
[
ipt
]
=
{}
app_to_idx
=
var_to_app_to_idx
[
ipt
]
if
app
not
in
app_to_idx
:
app_to_idx
[
app
]
=
[]
...
...
@@ -1052,7 +1046,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
"""
# build a dict mapping node to the terms node contributes to each of
# its inputs' gradients
term_dict
=
OrderedDict
()
term_dict
=
{}
def
access_term_cache
(
node
):
"""Populates term_dict[node] and returns it"""
...
...
@@ -1978,7 +1972,7 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
if
expression
.
ndim
==
0
:
# expression is just a scalar, use grad
return
format_as
(
return
as_list_or_tuple
(
using_list
,
using_tuple
,
grad
(
...
...
@@ -2013,7 +2007,7 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
non_sequences
=
[
expression
]
+
wrt
,
)
assert
not
updates
,
"Scan has returned a list of updates; this should not happen."
return
format_as
(
using_list
,
using_tuple
,
jacobs
)
return
as_list_or_tuple
(
using_list
,
using_tuple
,
jacobs
)
def
hessian
(
cost
,
wrt
,
consider_constant
=
None
,
disconnected_inputs
=
"raise"
):
...
...
@@ -2093,7 +2087,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
not
updates
),
"Scan has returned a list of updates; this should not happen."
hessians
.
append
(
hess
)
return
format_as
(
using_list
,
using_tuple
,
hessians
)
return
as_list_or_tuple
(
using_list
,
using_tuple
,
hessians
)
def
_is_zero
(
x
):
...
...
@@ -2134,7 +2128,6 @@ class ConsiderConstant(ViewOp):
consider_constant_
=
ConsiderConstant
()
# I create a function only to have the doc show well.
def
consider_constant
(
x
):
"""
DEPRECATED: use zero_grad() or disconnected_grad() instead.
...
...
tests/test_gradient.py
浏览文件 @
ccfe2d3d
...
...
@@ -278,8 +278,6 @@ class TestGrad:
g
=
grad
(
a1
.
outputs
[
0
],
a1
.
outputs
[
1
],
disconnected_inputs
=
"ignore"
)
assert
g
.
owner
.
op
==
at
.
fill
assert
g
.
owner
.
inputs
[
1
]
.
data
==
0
with
pytest
.
raises
(
TypeError
):
grad
(
a1
.
outputs
[
0
],
"wtf"
)
def
test_NNone_rval
(
self
):
# grad: Test returning some zero value from grad
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论