Commit ccfe2d3d authored by Brandon T. Willard, committed by Brandon T. Willard

Refactor aesara.gradient and add type hints

Parent 1d369b55
"""Driver for gradient calculations."""
import logging
import time
import warnings
from collections import OrderedDict
from functools import partial, reduce
from typing import TYPE_CHECKING, Callable, List, Optional, Union
from typing import (
TYPE_CHECKING,
Callable,
Dict,
List,
Mapping,
MutableSequence,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
import numpy as np
from typing_extensions import Literal
import aesara
from aesara.compile.ops import ViewOp
from aesara.configdefaults import config
from aesara.graph import utils
from aesara.graph.basic import NominalVariable, Variable
from aesara.graph.basic import Apply, NominalVariable, Variable
from aesara.graph.null_type import NullType, null_type
from aesara.graph.op import get_test_values
from aesara.graph.type import Type
@@ -23,26 +34,18 @@ if TYPE_CHECKING:
from aesara.compile.mode import Mode
__docformat__ = "restructuredtext en"
_logger = logging.getLogger("aesara.gradient")
V = TypeVar("V", bound=Optional[Variable])
# we can't do "import aesara.tensor"
# tensor depends on aesara.compile
# aesara.compile depends on aesara.gradient (this file)
# the reason aesara.compile depends on aesara.gradient
# is that aesara.compile.builders contains the op-from-graph
# (`OpFromGraph`) functionality, which uses aesara.gradient to
# implement the new `Op`'s grad method
tensor = None
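A hypothetical sketch (editorial, not part of the diff) of how such a deferred import is typically resolved at call time; the helper name `_get_tensor` is illustrative only:

def _get_tensor():
    # Hypothetical helper: import `aesara.tensor` on first use instead of
    # at module load time, sidestepping the aesara.gradient ->
    # aesara.compile -> aesara.gradient import cycle described above.
    global tensor
    if tensor is None:
        import aesara.tensor
        tensor = aesara.tensor
    return tensor
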
_msg_retType = "op.grad(...) returned a non-list"
# TODO: Refactor this so that it's not a global variable
grad_time: float = 0.0
grad_time = 0
def format_as(use_list, use_tuple, outputs):
"""
Formats the outputs according to the flags `use_list` and `use_tuple`.
# TODO: Add `overload` variants
def as_list_or_tuple(
use_list: bool, use_tuple: bool, outputs: Union[V, Sequence[V]]
) -> Union[V, List[V], Tuple[V, ...]]:
"""Return either a single object or a list/tuple of objects.
If `use_list` is True, `outputs` is returned as a list (if `outputs`
is not a list or a tuple then it is converted into a one-element list).
@@ -52,22 +55,25 @@ def format_as(use_list, use_tuple, outputs):
"""
if use_list and use_tuple:
raise ValueError("Both flags cannot be simultaneously True")
if (use_list or use_tuple) and not isinstance(outputs, (list, tuple)):
if use_list:
return [outputs]
else:
return (outputs,)
elif not (use_list or use_tuple) and isinstance(outputs, (list, tuple)):
if len(outputs) != 1:
raise ValueError("Wrong arguments; expected a one element list")
return outputs[0]
elif use_list or use_tuple:
if use_list:
return list(outputs)
if use_list or use_tuple:
if isinstance(outputs, Sequence):
if use_list:
return list(outputs)
else:
return tuple(outputs)
else:
return tuple(outputs)
if use_list:
return [outputs]
else:
return (outputs,)
else:
return outputs
if isinstance(outputs, Sequence):
if len(outputs) != 1:
raise ValueError("Wrong arguments; expected a one element list")
return outputs[0]
else:
return outputs
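An illustrative sketch (editorial, not part of the diff) of the contract the new helper is meant to satisfy, following the branches above:

x = object()  # stands in for any Variable
assert as_list_or_tuple(True, False, x) == [x]      # lone object wrapped in a list
assert as_list_or_tuple(False, True, (x,)) == (x,)  # sequence coerced to a tuple
assert as_list_or_tuple(False, False, [x]) is x     # one-element sequence unwrapped
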
def grad_not_implemented(op, x_pos, x, comment=""):
@@ -155,97 +161,87 @@ class DisconnectedType(Type):
disconnected_type = DisconnectedType()
########################
# R Operator
########################
def Rop(
f: Union[Variable, Sequence[Variable]],
wrt: Union[Variable, Sequence[Variable]],
eval_points: Union[Variable, Sequence[Variable]],
disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise",
return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
) -> Union[Optional[Variable], Sequence[Optional[Variable]]]:
"""Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`.
def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="zero"):
"""
Computes the R operation on `f` wrt to `wrt` at `eval_points`.
Mathematically this stands for the jacobian of `f` wrt
to `wrt` right muliplied by the eval points.
Mathematically this stands for the Jacobian of `f` right multiplied by the
`eval_points`.
Parameters
----------
f : :class:`~aesara.graph.basic.Variable` or list of Variables
`f` stands for the output of the computational graph to which you
want to apply the R operator
wrt : :class:`~aesara.graph.basic.Variable` or list of Variables
variables for which you compute the R operator of the expression
described by `f`
eval_points : :class:`~aesara.graph.basic.Variable` or list of Variables
evaluation points for each of the variables in `wrt`
disconnected_outputs : str
f
The outputs of the computational graph to which the R-operator is
applied.
wrt
Variables for which the R-operator of `f` is computed.
eval_points
Points at which to evaluate each of the variables in `wrt`.
disconnected_outputs
Defines the behaviour if some of the variables in `f`
have no dependency on any of the variables in `wrt` (or if
all links are non-differentiable). The possible values are:
- 'ignore': considers that the gradient on these parameters is zero.
- 'warn': consider the gradient zero, and print a warning.
- 'raise': raise DisconnectedInputError.
- ``'ignore'``: considers that the gradient on these parameters is zero.
- ``'warn'``: consider the gradient zero, and print a warning.
- ``'raise'``: raise `DisconnectedInputError`.
return_disconnected : {'zero', 'None', 'Disconnected'}
- 'zero' : If wrt[i] is disconnected, return value i will be
wrt[i].zeros_like()
- 'None' : If wrt[i] is disconnected, return value i will be
None
- 'Disconnected' : returns variables of type DisconnectedType
return_disconnected
- ``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
``wrt[i].zeros_like()``.
- ``'none'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
``None``
- ``'disconnected'`` : returns variables of type `DisconnectedType`
Returns
-------
:class:`~aesara.graph.basic.Variable` or list/tuple of Variables depending on type of f
Symbolic expression such that
R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]
A symbolic expression obeying
``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``,
where the indices in that expression are magic multidimensional
indices that specify both the position within a list and all
coordinates of the tensor element in the last.
coordinates of the tensor elements.
If `wrt` is a list/tuple, then return a list/tuple with the results.
"""
using_list = isinstance(f, list)
using_tuple = isinstance(f, tuple)
if not isinstance(wrt, (list, tuple)):
wrt = [wrt]
_wrt: List[Variable] = [aesara.tensor.as_tensor_variable(wrt)]
else:
_wrt = [aesara.tensor.as_tensor_variable(x) for x in wrt]
if not isinstance(eval_points, (list, tuple)):
eval_points = [eval_points]
_eval_points: List[Variable] = [aesara.tensor.as_tensor_variable(eval_points)]
else:
_eval_points = [aesara.tensor.as_tensor_variable(x) for x in eval_points]
if not isinstance(f, (list, tuple)):
f = [f]
_f: List[Variable] = [aesara.tensor.as_tensor_variable(f)]
else:
_f = [aesara.tensor.as_tensor_variable(x) for x in f]
if len(wrt) != len(eval_points):
if len(_wrt) != len(_eval_points):
raise ValueError("`wrt` must be the same length as `eval_points`.")
# Check that each element of wrt corresponds to an element
# of eval_points with the same dimensionality.
for pack in enumerate(zip(wrt, eval_points)):
i = pack[0]
wrt_elem, eval_point = pack[1]
if not isinstance(wrt_elem, Variable):
wrt_elem = aesara.tensor.as_tensor_variable(wrt_elem)
if not isinstance(eval_point, Variable):
eval_point = aesara.tensor.as_tensor_variable(eval_point)
for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points)):
try:
if wrt_elem.type.ndim != eval_point.type.ndim:
raise ValueError(
"Element "
+ str(i)
+ " of wrt/eval_point have mismatched "
+ "dimensionality: "
+ str(wrt_elem.type.ndim)
+ " versus "
+ str(eval_point.type.ndim)
f"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: "
f"{wrt_elem.type.ndim} and {eval_point.type.ndim}"
)
except AttributeError:
# wrt_elem and eval_point don't always have an `ndim` attribute
# (e.g. random types); Tensor and Sparse types do
pass
seen_nodes = OrderedDict()
seen_nodes: Dict[Apply, Sequence[Variable]] = {}
def _traverse(node):
"""TODO: writeme"""
@@ -260,8 +256,8 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="
# inputs of the node
local_eval_points = []
for inp in inputs:
if inp in wrt:
local_eval_points.append(eval_points[wrt.index(inp)])
if inp in _wrt:
local_eval_points.append(_eval_points[_wrt.index(inp)])
elif inp.owner is None:
try:
local_eval_points.append(inp.zeros_like())
@@ -316,13 +312,13 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="
# end _traverse
# Populate the dictionary
for out in f:
for out in _f:
_traverse(out.owner)
rval = []
for out in f:
if out in wrt:
rval.append(eval_points[wrt.index(out)])
rval: List[Optional[Variable]] = []
for out in _f:
if out in _wrt:
rval.append(_eval_points[_wrt.index(out)])
elif (
seen_nodes.get(out.owner, None) is None
or seen_nodes[out.owner][out.owner.outputs.index(out)] is None
@@ -361,81 +357,89 @@ def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="
else:
rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)])
return format_as(using_list, using_tuple, rval)
using_list = isinstance(f, list)
using_tuple = isinstance(f, tuple)
return as_list_or_tuple(using_list, using_tuple, rval)
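A minimal usage sketch (editorial, not part of the diff; assumes a working Aesara installation). `Rop` computes a Jacobian-vector product, here for ``y = x**2`` whose Jacobian is ``diag(2*x)``:

import aesara
import aesara.tensor as at
from aesara.gradient import Rop

x = at.vector("x")
v = at.vector("v")
y = x**2
jv = Rop(y, x, v)  # symbolic J(x) @ v, i.e. 2 * x * v here
fn = aesara.function([x, v], jv)
# fn([1.0, 2.0], [1.0, 1.0]) is approximately [2.0, 4.0]
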
def Lop(f, wrt, eval_points, consider_constant=None, disconnected_inputs="raise"):
"""Computes the L operation on `f` with respect to `wrt` at `eval_points`.
def Lop(
f: Union[Variable, Sequence[Variable]],
wrt: Union[Variable, Sequence[Variable]],
eval_points: Union[Variable, Sequence[Variable]],
consider_constant: Optional[Sequence[Variable]] = None,
disconnected_inputs: Literal["ignore", "warn", "raise"] = "raise",
) -> Union[Optional[Variable], Sequence[Optional[Variable]]]:
"""Computes the L-operator applied to `f` with respect to `wrt` at `eval_points`.
Mathematically this stands for the Jacobian of `f` with respect to `wrt`
left multiplied by the `eval_points`.
Parameters
----------
f : :class:`~aesara.graph.basic.Variable` or list of Variables
`f` stands for the output of the computational graph to which you
want to apply the L operator
wrt : :class:`~aesara.graph.basic.Variable` or list of Variables
variables for which you compute the L operator of the expression
described by `f`
eval_points : :class:`~aesara.graph.basic.Variable` or list of Variables
evaluation points for each of the variables in `f`
f
The outputs of the computational graph to which the L-operator is
applied.
wrt
Variables for which the L-operator of `f` is computed.
eval_points
Points at which to evaluate each of the variables in `f`.
consider_constant
See `grad`.
disconnected_inputs
See `grad`.
Returns
-------
:class:`~aesara.graph.basic.Variable` or list/tuple of Variables depending on type of `f`
Symbolic expression such that
A symbolic expression satisfying
``L_op[j] = sum_i (d f[i] / d wrt[j]) eval_point[i]``
where the indices in that expression are magic multidimensional
indices that specify both the position within a list and all
coordinates of the tensor element in the last
coordinates of the tensor elements.
If `f` is a list/tuple, then return a list/tuple with the results.
"""
if not isinstance(eval_points, (list, tuple)):
eval_points = [eval_points]
using_list = isinstance(wrt, list)
using_tuple = isinstance(wrt, tuple)
_eval_points: List[Variable] = [aesara.tensor.as_tensor_variable(eval_points)]
else:
_eval_points = [aesara.tensor.as_tensor_variable(x) for x in eval_points]
if not isinstance(f, (list, tuple)):
f = [f]
_f: List[Variable] = [aesara.tensor.as_tensor_variable(f)]
else:
_f = [aesara.tensor.as_tensor_variable(x) for x in f]
# make copies of f and grads so we don't modify the client's copy
f = list(f)
grads = list(eval_points)
grads = list(_eval_points)
if not isinstance(wrt, (list, tuple)):
wrt = [wrt]
_wrt: List[Variable] = [aesara.tensor.as_tensor_variable(wrt)]
else:
_wrt = [aesara.tensor.as_tensor_variable(x) for x in wrt]
assert len(f) == len(grads)
known = OrderedDict(zip(f, grads))
assert len(_f) == len(grads)
known = dict(zip(_f, grads))
ret = grad(
cost=None,
known_grads=known,
consider_constant=consider_constant,
wrt=wrt,
wrt=_wrt,
disconnected_inputs=disconnected_inputs,
)
return format_as(using_list, using_tuple, ret)
#########################
# Gradient
#########################
using_list = isinstance(wrt, list)
using_tuple = isinstance(wrt, tuple)
return as_list_or_tuple(using_list, using_tuple, ret)
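A minimal usage sketch (editorial, not part of the diff; assumes a working Aesara installation). `Lop` computes a vector-Jacobian product, the transpose counterpart of `Rop` above:

import aesara
import aesara.tensor as at
from aesara.gradient import Lop

x = at.vector("x")
v = at.vector("v")  # one evaluation point per output in `f`
y = x**2
vj = Lop(y, x, v)  # symbolic v @ J(x), i.e. 2 * x * v here
fn = aesara.function([x, v], vj)
# fn([1.0, 2.0], [1.0, 1.0]) is approximately [2.0, 4.0]
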
def grad(
cost,
wrt,
consider_constant=None,
disconnected_inputs="raise",
add_names=True,
known_grads=None,
return_disconnected="zero",
null_gradients="raise",
):
cost: Optional[Variable],
wrt: Union[Variable, Sequence[Variable]],
consider_constant: Optional[Sequence[Variable]] = None,
disconnected_inputs: Literal["ignore", "warn", "raise"] = "raise",
add_names: bool = True,
known_grads: Optional[Mapping[Variable, Variable]] = None,
return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
null_gradients: Literal["raise", "return"] = "raise",
) -> Union[Optional[Variable], Sequence[Optional[Variable]]]:
"""
Return symbolic gradients of one cost with respect to one or more variables.
@@ -445,49 +449,47 @@ def grad(
Parameters
----------
cost : :class:`~aesara.graph.basic.Variable` scalar (0-dimensional) tensor variable or ``None``
Value that we are differentiating (that we want the gradient of).
May be `None` if `known_grads` is provided.
wrt : :class:`~aesara.graph.basic.Variable` or list of Variables
Term[s] with respect to which we want gradients
consider_constant : list of variables
Expressions not to backpropagate through
cost
Value that we are differentiating (i.e. for which we want the
gradient). May be `None` if `known_grads` is provided.
wrt
The term(s) with respect to which we want gradients.
consider_constant
Expressions not to backpropagate through.
disconnected_inputs : {'ignore', 'warn', 'raise'}
Defines the behaviour if some of the variables in `wrt` are
not part of the computational graph computing `cost` (or if
all links are non-differentiable). The possible values are:
- 'ignore': considers that the gradient on these parameters is zero.
- 'warn': consider the gradient zero, and print a warning.
- 'raise': raise DisconnectedInputError.
add_names : bool
If True, variables generated by grad will be named
(d<cost.name>/d<wrt.name>) provided that both cost and wrt
have names
known_grads : OrderedDict, optional
A ordered dictionary mapping variables to their gradients. This is
useful in the case where you know the gradient on some
- ``'ignore'``: considers that the gradient on these parameters is zero
- ``'warn'``: consider the gradient zero, and print a warning
- ``'raise'``: raise `DisconnectedInputError`
add_names
If ``True``, variables generated by `grad` will be named
``(d<cost.name>/d<wrt.name>)`` provided that both `cost` and `wrt`
have names.
known_grads
An ordered dictionary mapping variables to their gradients. This is
useful in the case where you know the gradients of some
variables but do not know the original cost.
return_disconnected : {'zero', 'None', 'Disconnected'}
- 'zero' : If wrt[i] is disconnected, return value i will be
wrt[i].zeros_like()
- 'None' : If wrt[i] is disconnected, return value i will be
None
- 'Disconnected' : returns variables of type DisconnectedType
null_gradients : {'raise', 'return'}
Defines the behaviour if some of the variables in `wrt` have a
return_disconnected
- ``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
``wrt[i].zeros_like()``
- ``'none'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
``None``
- ``'disconnected'`` : returns variables of type `DisconnectedType`
null_gradients
Defines the behaviour when some of the variables in `wrt` have a
null gradient. The possible values are:
- 'raise' : raise a NullTypeGradError exception
- 'return' : return the null gradients
- ``'raise'`` : raise a `NullTypeGradError` exception
- ``'return'`` : return the null gradients
Returns
-------
variable or list/tuple of variables (matches `wrt`)
Symbolic expression of gradient of `cost` with respect to each
of the `wrt` terms. If an element of `wrt` is not
differentiable with respect to the output, then a zero
variable is returned.
A symbolic expression for the gradient of `cost` with respect to each
of the `wrt` terms. If an element of `wrt` is not differentiable with
respect to the output, then a zero variable is returned.
"""
t0 = time.time()
@@ -498,30 +500,17 @@ def grad(
if cost is not None and isinstance(cost.type, NullType):
raise ValueError(
"Can't differentiate a NaN cost."
"cost is NaN because " + cost.type.why_null
)
if cost is not None and cost.ndim != 0:
raise TypeError("cost must be a scalar.")
if isinstance(wrt, set):
raise TypeError(
"wrt must not be a set. sets have no defined "
"iteration order, so we can't return gradients in a"
" matching order."
"Can't differentiate a NaN cost. "
f"Cost is NaN because {cost.type.why_null}"
)
using_list = isinstance(wrt, list)
using_tuple = isinstance(wrt, tuple)
if not using_list and not using_tuple:
wrt = [wrt]
if cost is not None and cost.type.ndim != 0:
raise TypeError("Cost must be a scalar.")
for elem in wrt:
if not isinstance(elem, Variable):
raise TypeError(
"Expected Variable, got " + str(elem) + " of type " + str(type(elem))
)
if not isinstance(wrt, Sequence):
_wrt: List[Variable] = [wrt]
else:
_wrt = [x for x in wrt]
outputs = []
if cost is not None:
@@ -529,16 +518,15 @@ def grad(
if known_grads is not None:
outputs.extend(list(known_grads.keys()))
var_to_app_to_idx = _populate_var_to_app_to_idx(outputs, wrt, consider_constant)
var_to_app_to_idx = _populate_var_to_app_to_idx(outputs, _wrt, consider_constant)
# build a dict mapping var to the gradient of cost with respect to var
grad_dict = OrderedDict()
grad_dict = {}
if known_grads is None:
known_grads = OrderedDict()
else:
m = "known_grads must be an OrderedDict. "
assert isinstance(known_grads, OrderedDict) or len(known_grads) <= 1, m
known_grads = {}
assert isinstance(known_grads, dict)
# The gradient of the cost is 1 unless specified otherwise by known_grads.
if cost is not None:
@@ -615,7 +603,7 @@ def grad(
# if wrt is such a variable, populate the grad_dict with this info
# so that wrt not being in var_to_app_to_idx won't cause an error below
# according to the flag, possibly raise an error if wrt is disconnected
for elem in wrt:
for elem in _wrt:
if elem not in var_to_app_to_idx and elem is not cost and elem not in grad_dict:
handle_disconnected(elem)
grad_dict[elem] = disconnected_type()
@@ -632,32 +620,38 @@ def grad(
if hasattr(g.type, "dtype"):
assert g.type.dtype in aesara.tensor.type.float_dtypes
rval = _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
_rval: Sequence[Variable] = _populate_grad_dict(
var_to_app_to_idx, grad_dict, _wrt, cost_name
)
rval: MutableSequence[Optional[Variable]] = list(_rval)
for i in range(len(rval)):
if isinstance(rval[i].type, NullType):
for i in range(len(_rval)):
if isinstance(_rval[i].type, NullType):
if null_gradients == "raise":
raise NullTypeGradError(
f"grad encountered a NaN. {rval[i].type.why_null}"
f"`grad` encountered a NaN. {_rval[i].type.why_null}"
)
else:
assert null_gradients == "return"
if isinstance(rval[i].type, DisconnectedType):
handle_disconnected(rval[i])
if isinstance(_rval[i].type, DisconnectedType):
handle_disconnected(_rval[i])
if return_disconnected == "zero":
rval[i] = _float_zeros_like(wrt[i])
elif return_disconnected == "None":
rval[i] = _float_zeros_like(_wrt[i])
elif return_disconnected.lower() == "none":
rval[i] = None
else:
assert return_disconnected == "Disconnected"
assert return_disconnected.lower() == "disconnected"
if using_tuple:
rval = tuple(rval)
elif not using_list:
(rval,) = rval
t1 = time.time()
global grad_time
grad_time += t1 - t0
if isinstance(wrt, tuple):
return tuple(rval)
elif not isinstance(wrt, list):
return rval[0]
return rval
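A minimal usage sketch of `grad` (editorial, not part of the diff; assumes a working Aesara installation):

import aesara
import aesara.tensor as at
from aesara.gradient import grad

x = at.scalar("x")
cost = x**2
g = grad(cost, x)  # symbolic 2 * x
fn = aesara.function([x], g)
# fn(3.0) is approximately 6.0
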
@@ -801,7 +795,7 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
for i in range(len(grads)):
grads[i] += cost_grads[i]
pgrads = OrderedDict(zip(params, grads))
pgrads = dict(zip(params, grads))
# separate wrt from end grads:
wrt_grads = list(pgrads[k] for k in wrt)
end_grads = list(pgrads[k] for k in end)
@@ -916,7 +910,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
# var_to_app_to_idx[var][node] = [i,j] means node has
# var as input at positions i and j
var_to_app_to_idx = OrderedDict()
var_to_app_to_idx = dict()
# Set of variables that have been added to their true parents
# ('true' here means that the elements of the variable are a function
@@ -954,13 +948,13 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
continue
if ipt not in var_to_app_to_idx:
# This object here *must* be an OrderedDict, because
# This object here *must* be ordered, because
# we iterate over its keys when adding up the terms of the
# gradient on ipt. If it is a regular dict, the grad method
# will return something that is analytically correct, but
# whose order of doing additions depends on the memory
# location of the apply nodes.
var_to_app_to_idx[ipt] = OrderedDict()
var_to_app_to_idx[ipt] = {}
app_to_idx = var_to_app_to_idx[ipt]
if app not in app_to_idx:
app_to_idx[app] = []
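An editorial aside on why a plain `dict` now suffices in the comment above: CPython 3.7+ guarantees insertion-ordered iteration for `dict`, so the gradient terms are still summed in a deterministic order. A quick demonstration:

d = {}
d["b"] = 1
d["a"] = 2
# Iteration follows insertion order, not key hashes or memory addresses:
assert list(d) == ["b", "a"]
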
@@ -1052,7 +1046,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
"""
# build a dict mapping node to the terms node contributes to each of
# its inputs' gradients
term_dict = OrderedDict()
term_dict = {}
def access_term_cache(node):
"""Populates term_dict[node] and returns it"""
@@ -1978,7 +1972,7 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise"
if expression.ndim == 0:
# expression is just a scalar, use grad
return format_as(
return as_list_or_tuple(
using_list,
using_tuple,
grad(
@@ -2013,7 +2007,7 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise"
non_sequences=[expression] + wrt,
)
assert not updates, "Scan has returned a list of updates; this should not happen."
return format_as(using_list, using_tuple, jacobs)
return as_list_or_tuple(using_list, using_tuple, jacobs)
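A minimal usage sketch of `jacobian` (editorial, not part of the diff; assumes a working Aesara installation):

import aesara
import aesara.tensor as at
from aesara.gradient import jacobian

x = at.vector("x")
y = x**2
J = jacobian(y, x)  # symbolic diag(2 * x)
fn = aesara.function([x], J)
# fn([1.0, 2.0]) is approximately [[2.0, 0.0], [0.0, 4.0]]
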
def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
@@ -2093,7 +2087,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
not updates
), "Scan has returned a list of updates; this should not happen."
hessians.append(hess)
return format_as(using_list, using_tuple, hessians)
return as_list_or_tuple(using_list, using_tuple, hessians)
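A minimal usage sketch of `hessian` (editorial, not part of the diff): `cost` must be a scalar and each element of `wrt` a vector:

import aesara
import aesara.tensor as at
from aesara.gradient import hessian

x = at.vector("x")
cost = (x**2).sum()
H = hessian(cost, x)  # symbolic 2 * I
fn = aesara.function([x], H)
# fn([1.0, 2.0]) is approximately [[2.0, 0.0], [0.0, 2.0]]
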
def _is_zero(x):
@@ -2134,7 +2128,6 @@ class ConsiderConstant(ViewOp):
consider_constant_ = ConsiderConstant()
# This exists as a function only so that its docstring renders well.
def consider_constant(x):
"""
DEPRECATED: use zero_grad() or disconnected_grad() instead.
......
@@ -278,8 +278,6 @@ class TestGrad:
g = grad(a1.outputs[0], a1.outputs[1], disconnected_inputs="ignore")
assert g.owner.op == at.fill
assert g.owner.inputs[1].data == 0
with pytest.raises(TypeError):
grad(a1.outputs[0], "wtf")
def test_NNone_rval(self):
# grad: Test returning some zero value from grad
......