提交 57c388a7 authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Add axis to LogSoftmax

上级 595ed184
...@@ -208,8 +208,10 @@ def jax_funcify_Softmax(op, **kwargs): ...@@ -208,8 +208,10 @@ def jax_funcify_Softmax(op, **kwargs):
@jax_funcify.register(LogSoftmax) @jax_funcify.register(LogSoftmax)
def jax_funcify_LogSoftmax(op, **kwargs): def jax_funcify_LogSoftmax(op, **kwargs):
axis = op.axis
def log_softmax(x): def log_softmax(x):
return jax.nn.log_softmax(x) return jax.nn.log_softmax(x, axis=axis)
return log_softmax return log_softmax
......
...@@ -430,15 +430,22 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): ...@@ -430,15 +430,22 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
x_at = node.inputs[0] x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype) x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
axis = op.axis
# np.max(x, axis=1) if axis is not None:
reduce_max = create_axis_reducer(np.maximum, -np.inf, 1, x_at.ndim, x_dtype) reduce_max = create_axis_reducer(
# np.sum(x, axis=1, keepdims=True) np.maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
reduce_sum = create_axis_reducer(np.add, 0.0, 1, x_at.ndim, x_dtype, keepdims=True) )
reduce_sum = create_axis_reducer(
np.add, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
)
else:
reduce_max = np.max
reduce_sum = np.sum
@numba.njit @numba.njit
def log_softmax(x): def log_softmax(x):
xdev = x - np.expand_dims(reduce_max(x), -1) xdev = x - reduce_max(x)
lsm = xdev - np.log(reduce_sum(np.exp(xdev))) lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
return lsm return lsm
......
...@@ -27,7 +27,6 @@ from aesara.tensor.nnet.basic import ( ...@@ -27,7 +27,6 @@ from aesara.tensor.nnet.basic import (
graph_merge_softmax_with_crossentropy_softmax, graph_merge_softmax_with_crossentropy_softmax,
h_softmax, h_softmax,
logsoftmax, logsoftmax,
logsoftmax_op,
prepend_0_to_each_row, prepend_0_to_each_row,
prepend_1_to_each_row, prepend_1_to_each_row,
prepend_scalar_to_each_row, prepend_scalar_to_each_row,
......
...@@ -822,34 +822,34 @@ class LogSoftmax(COp): ...@@ -822,34 +822,34 @@ class LogSoftmax(COp):
""" """
__props__ = () nin = 1
nout = 1
__props__ = ("axis",)
def __init__(self, axis):
if axis is not None and not isinstance(axis, int):
raise TypeError("axis must be an integer or `None`")
self.axis = axis
def make_node(self, x): def make_node(self, x):
x = aet.as_tensor_variable(x) x = aet.as_tensor_variable(x)
if x.type.ndim not in (1, 2) or x.type.dtype not in float_dtypes:
raise ValueError(f"x must be 1-d or 2-d tensor of floats. Got {x.type}") if self.axis is not None and (self.axis >= x.ndim or self.axis < -x.ndim):
if x.ndim == 1: raise ValueError(
warnings.warn( f"LogSoftmax axis(={self.axis}) out of bounds for {x.ndim}D array {x}"
"If x is a vector, LogSoftmax will not automatically pad x "
"anymore in next releases. If you need it, please do it manually. The "
"vector case is gonna be supported soon and the output will be a vector.",
category=PendingDeprecationWarning,
stacklevel=4,
) )
x = shape_padleft(x, n_ones=1)
return Apply(self, [x], [x.type()]) return Apply(self, [x], [x.type()])
def perform(self, node, input_storage, output_storage): def perform(self, node, input_storage, output_storage):
(x,) = input_storage (x,) = input_storage
xdev = x - x.max(axis=1)[:, None] (z,) = output_storage
lsm = xdev - np.log(np.sum(np.exp(xdev), axis=1, keepdims=True)) z[0] = scipy.special.log_softmax(x, axis=self.axis)
output_storage[0][0] = lsm
def grad(self, inp, grads): def grad(self, inp, grads):
(x,) = inp (x,) = inp
sm = softmax_legacy(x) sm = Softmax(axis=self.axis)(x)
return [grads[0] - aet_sum(grads[0], axis=1, keepdims=True) * sm] return [grads[0] - aet_sum(grads[0], axis=self.axis, keepdims=True) * sm]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
# I think the Jacobian is symmetric so the R_op # I think the Jacobian is symmetric so the R_op
...@@ -864,100 +864,196 @@ class LogSoftmax(COp): ...@@ -864,100 +864,196 @@ class LogSoftmax(COp):
def c_headers(self, **kwargs): def c_headers(self, **kwargs):
return ["<cmath>"] return ["<cmath>"]
@staticmethod def c_code(self, node, name, inp, out, sub):
def c_code_template(dtype): (x,) = inp
init_decl = """ (sm,) = out
npy_intp* Nx = PyArray_DIMS(%(x)s); axis = self.axis if self.axis is not None else np.MAXDIMS
npy_intp Sx1 = 0; fail = sub["fail"]
npy_intp Ssm1 = 0;
if (PyArray_NDIM(%(x)s) != 2) return dedent(
{ f"""
PyErr_SetString(PyExc_ValueError, "not a 2d tensor"); PyArrayObject* op[2];
%(fail)s; npy_uint32 op_flags[2];
} npy_uint32 iter_flags;
if ((PyArray_TYPE(%(x)s) != NPY_DOUBLE) && NpyIter* iter;
(PyArray_TYPE(%(x)s) != NPY_FLOAT)) NpyIter_IterNextFunc* get_next;
{ char** data_ptr;
int x_ndim = PyArray_NDIM({x});
int axis = {axis};
int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1);
// Validate inputs
if ((PyArray_TYPE({x}) != NPY_DOUBLE) &&
(PyArray_TYPE({x}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "not a float"); PyErr_SetString(PyExc_TypeError, "not a float");
%(fail)s; {fail}
} }}
if ((NULL == %(sm)s) if (axis < 0) axis = x_ndim + axis;
|| (PyArray_DIMS(%(sm)s)[0] != PyArray_DIMS(%(x)s)[0]) if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
|| (PyArray_DIMS(%(sm)s)[1] != PyArray_DIMS(%(x)s)[1])) {{
{ PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax");
Py_XDECREF(%(sm)s); {fail}
%(sm)s = (PyArrayObject*)PyArray_SimpleNew( }}
2, PyArray_DIMS(%(x)s),
PyArray_TYPE(%(x)s));
if(!%(sm)s) {
PyErr_SetString(PyExc_MemoryError,
"failed to alloc sm output");
%(fail)s
}
}
Sx1 = PyArray_STRIDES(%(x)s)[1]/sizeof(dtype_%(x)s);
Ssm1 = PyArray_STRIDES(%(sm)s)[1]/sizeof(dtype_%(sm)s);
"""
begin_row_loop = """ // Allocate Output Array
// minibatch loop if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim)))
for (size_t i = 0; i < Nx[0]; ++i) {{
{ Py_XDECREF({sm});
size_t j; {sm} = (PyArrayObject*)PyArray_SimpleNew(x_ndim, PyArray_DIMS({x}), PyArray_TYPE({x}));
double sum = 0.0; if(!{sm}) {{
PyErr_SetString(PyExc_MemoryError, "failed to alloc LogSoftmax output");
{fail}
}}
}}
const dtype_%(x)s* __restrict__ x_i = (dtype_%(x)s*)( // Create numpy iterator
PyArray_BYTES(%(x)s) + PyArray_STRIDES(%(x)s)[0] * i); op[0] = {x};
dtype_%(sm)s* __restrict__ sm_i = (dtype_%(sm)s*)( op[1] = {sm};
PyArray_BYTES(%(sm)s) + PyArray_STRIDES(%(sm)s)[0] * i); op_flags[0] = NPY_ITER_READONLY;
op_flags[1] = NPY_ITER_READWRITE;
iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
iter = NpyIter_MultiNew(
2,
op,
iter_flags,
NPY_KEEPORDER,
NPY_NO_CASTING,
op_flags,
NULL
);
dtype_%(sm)s row_max = x_i[0]; if (iter == NULL)
// Get the maximum value of the row {{
for (j = 1; j < Nx[1]; ++j) PyErr_SetString(PyExc_MemoryError, "failed to create LogSoftmax iterator");
{ {fail}
dtype_%(sm)s x_ij = x_i[j * Sx1] ; }}
row_max = (x_ij > row_max) ? x_ij : row_max;
}
"""
inside_row_loop = """ // LogSoftmax is applied across the entire array
// Compute xdev and sum(exp(xdev), axis=1) if (!iterate_axis)
double xdev_exp_row_sum = 0.0; {{
for (j = 0; j < Nx[1]; j++) get_next = NpyIter_GetIterNext(iter, NULL);
{ if (get_next == NULL)
// use sm_i to temporary store xdev {{
sm_i[j * Ssm1] = (dtype_%(sm)s) (x_i[j * Sx1] - row_max); NpyIter_Deallocate(iter);
xdev_exp_row_sum += exp(sm_i[j * Ssm1]); PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax GetIterNext");
} {fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
// Write sm = xdev - log(sum(exp(xdev), axis=1)) // Find axis max
xdev_exp_row_sum = log(xdev_exp_row_sum); dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
for (j = 0; j < Nx[1]; ++j) dtype_{x} max = *x_ptr;
{ if (get_next(iter))
sm_i[j * Ssm1] -= (dtype_%(sm)s) xdev_exp_row_sum; {{
} do
""" {{
end_row_loop = """ dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
} max = (*x_ptr > max)? *x_ptr : max;
""" }} while(get_next(iter));
return (init_decl, begin_row_loop, inside_row_loop, end_row_loop) }}
def c_code(self, node, name, inp, out, sub): // Reset Iterator
(x,) = inp if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
(sm,) = out {{
code_template = "".join( PyErr_SetString(PyExc_RuntimeError, "Failed to reset LogSoftmax iterator");
self.c_code_template(node.inputs[0].type.dtype_specs()[1]) {fail}
}}
// Compute xdev and sum(exp(xdev))
dtype_{sm} sum_exp_xdev = 0.0;
do
{{
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr = (dtype_{sm})((*x_ptr) - max);
sum_exp_xdev += exp(*sm_ptr);
}} while(get_next(iter));
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset LogSoftmax iterator");
{fail}
}}
// Subtract log(sum(exp(xdev)))
dtype_{sm} log_sum_exp_xdev = log(sum_exp_xdev);
do
{{
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr -= log_sum_exp_xdev;
}} while(get_next(iter));
}}
// LogSoftmax is applied across a specific axis
else {{
// Collect axis strides and remove it from iteration
npy_intp axis_size = PyArray_DIM({x}, axis);
npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
if (axis_stride == NULL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax axis strides");
{fail}
}}
npy_intp x_axis_stride = axis_stride[0] / sizeof(dtype_{x});
npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to remove LogSoftmax axis from iterator");
{fail}
}}
// Iterate over remaining axes
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax GetIterNext");
{fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
do
{{
dtype_{x}* x_axis = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
// Find axis max
dtype_{x} max = x_axis[0];
for (npy_intp i = 1; i < axis_size; i++)
{{
dtype_{x} x_val = x_axis[i * x_axis_stride];
max = (x_val > max)? x_val : max;
}}
// Compute xdev and sum(exp(xdev))
dtype_{sm} sum_exp_xdev = 0.0;
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] = (dtype_{x})(x_axis[i * x_axis_stride] - max);
sum_exp_xdev += exp(sm_axis[i * sm_axis_stride]);
}}
// Subtract log(sum(exp(xdev))
dtype_{sm} log_sum_exp_xdev = log(sum_exp_xdev);
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] -= log_sum_exp_xdev;
}}
}} while(get_next(iter));
}}
NpyIter_Deallocate(iter);
"""
) )
return code_template % dict(locals(), **sub)
@staticmethod @staticmethod
def c_code_cache_version(): def c_code_cache_version():
return (0,) return (1,)
logsoftmax_op = LogSoftmax()
# This is not registered in stabilize, as it cause some crossentropy # This is not registered in stabilize, as it cause some crossentropy
...@@ -975,11 +1071,10 @@ def local_logsoftmax(fgraph, node): ...@@ -975,11 +1071,10 @@ def local_logsoftmax(fgraph, node):
and isinstance(node.op.scalar_op, aes.Log) and isinstance(node.op.scalar_op, aes.Log)
and len(node.inputs) == 1 and len(node.inputs) == 1
and node.inputs[0].owner is not None and node.inputs[0].owner is not None
and node.inputs[0].owner.op == softmax_legacy and isinstance(node.inputs[0].owner.op, Softmax)
and node.inputs[0].ndim == 2
): ):
inVars = node.inputs[0].owner.inputs[0] inVars = node.inputs[0].owner.inputs[0]
new_op = LogSoftmax() new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
ret = new_op(inVars) ret = new_op(inVars)
ret.tag.values_eq_approx = values_eq_approx_remove_inf ret.tag.values_eq_approx = values_eq_approx_remove_inf
copy_stack_trace([node.inputs[0], node.outputs[0]], ret) copy_stack_trace([node.inputs[0], node.outputs[0]], ret)
...@@ -1054,8 +1149,23 @@ def softmax(c, axis=UNSET_AXIS): ...@@ -1054,8 +1149,23 @@ def softmax(c, axis=UNSET_AXIS):
return Softmax(axis=axis)(c) return Softmax(axis=axis)(c)
def logsoftmax(c): def logsoftmax(c, axis=UNSET_AXIS):
return logsoftmax_op(c) if axis is UNSET_AXIS:
warnings.warn(
"logsoftmax now accepts an axis argument. For backwards-compatibility it defaults to -1 when not specified, "
"but in the future the default will be `None`.\nTo suppress this warning specify axis explicitly.",
FutureWarning,
)
axis = -1
c = as_tensor_variable(c)
if c.ndim == 1:
# TODO: Create Specific warning type that can be suppressed?
warnings.warn(
"Softmax no longer converts a vector to a row matrix.",
UserWarning,
)
return LogSoftmax(axis=axis)(c)
@register_specialize("fast_compile_gpu") @register_specialize("fast_compile_gpu")
......
...@@ -969,16 +969,21 @@ def test_nnet(): ...@@ -969,16 +969,21 @@ def test_nnet():
fgraph = FunctionGraph([x], [out]) fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = aet_nnet.logsoftmax(x)
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis):
x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = aet_nnet.softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out]) fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.parametrize("axis", [None, 0, 1]) @pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis): def test_logsoftmax(axis):
x = matrix("x") x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = aet_nnet.softmax(x, axis=axis) out = aet_nnet.logsoftmax(x, axis=axis)
fgraph = FunctionGraph([x], [out]) fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
......
...@@ -1930,20 +1930,27 @@ def test_Softmax(x, axis, exc): ...@@ -1930,20 +1930,27 @@ def test_Softmax(x, axis, exc):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, exc", "x, axis, exc",
[ [
( (
set_test_value(aet.vector(), rng.random(size=(2,)).astype(config.floatX)), set_test_value(aet.vector(), rng.random(size=(2,)).astype(config.floatX)),
None, None,
None,
), ),
( (
set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
0,
None,
),
(
set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
1,
None, None,
), ),
], ],
) )
def test_LogSoftmax(x, exc): def test_LogSoftmax(x, axis, exc):
g = nnetb.LogSoftmax()(x) g = nnetb.LogSoftmax(axis=axis)(x)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
......
...@@ -47,7 +47,6 @@ from aesara.tensor.nnet.basic import ( ...@@ -47,7 +47,6 @@ from aesara.tensor.nnet.basic import (
elu, elu,
h_softmax, h_softmax,
logsoftmax, logsoftmax,
logsoftmax_op,
relu, relu,
selu, selu,
sigmoid_binary_crossentropy, sigmoid_binary_crossentropy,
...@@ -205,47 +204,28 @@ class TestSoftmaxWithBias(utt.InferShapeTester): ...@@ -205,47 +204,28 @@ class TestSoftmaxWithBias(utt.InferShapeTester):
class TestLogSoftmax(utt.InferShapeTester): class TestLogSoftmax(utt.InferShapeTester):
def test_basic(self): @pytest.mark.parametrize("column", [0, 1, 2, 3])
def f(a): @pytest.mark.parametrize("axis", [None, 0, 1])
return logsoftmax_op(a)[:, 0] def test_matrix_grad(self, axis, column):
utt.verify_grad(f, [np.random.random((3, 4))])
def f(a):
return logsoftmax_op(a)[:, 1]
utt.verify_grad(f, [np.random.random((3, 4))])
def f(a):
return logsoftmax_op(a)[:, 2]
utt.verify_grad(f, [np.random.random((3, 4))])
def f(a):
return logsoftmax_op(a)[:, 3]
utt.verify_grad(f, [np.random.random((3, 4))])
def test_matrix(self):
def f(a): def f(a):
return logsoftmax_op(a) return logsoftmax(a, axis=axis)[:, column]
utt.verify_grad(f, [np.random.random((3, 4))]) utt.verify_grad(f, [np.random.random((3, 4))])
def test_vector(self): def test_vector_perform(self):
x = vector() x = vector()
f = aesara.function([x], logsoftmax_op(x)) f = aesara.function([x], logsoftmax(x, axis=None))
xv = np.random.randn(6).astype(config.floatX) xv = np.random.randn(6).astype(config.floatX)
assert np.allclose(f(xv), np.log(np.exp(xv) / np.exp(xv).sum())) assert np.allclose(f(xv), sp.log_softmax(xv))
def test_vector_grad(self): def test_vector_grad(self):
def f(a): def f(a):
return logsoftmax_op(a) return logsoftmax(a, axis=None)
utt.verify_grad(f, [np.random.random((4))]) utt.verify_grad(f, [np.random.random((4))])
def test_allclose(self): def test_matrix_perform_and_opt(self):
m = config.mode m = config.mode
m = aesara.compile.get_mode(m) m = aesara.compile.get_mode(m)
m.check_isfinite = False m.check_isfinite = False
...@@ -284,18 +264,15 @@ class TestLogSoftmax(utt.InferShapeTester): ...@@ -284,18 +264,15 @@ class TestLogSoftmax(utt.InferShapeTester):
grad_ = f3(a, b) grad_ = f3(a, b)
assert not np.any(np.isnan(grad_)) assert not np.any(np.isnan(grad_))
def test_isclose(self): @pytest.mark.parametrize("axis", [None, 0, -1])
def f(a): def test_local_logsoftmax_opt(self, axis):
return logsoftmax_op(a)
def test_local_softmax_optimization(self):
# Test the Logsoftmax substitution # Test the Logsoftmax substitution
# #
# Check that Log(Softmax(x)) is substituted with Logsoftmax(x). Note that # Check that Log(Softmax(x)) is substituted with Logsoftmax(x). Note that
# only the forward pass is checked (i.e., doesn't check the gradient) # only the forward pass is checked (i.e., doesn't check the gradient)
x, y = matrices("xy") x = matrix("x")
sm = softmax(x) sm = softmax(x, axis=axis)
logsm = log(sm) logsm = log(sm)
f = aesara.function([x], logsm) f = aesara.function([x], logsm)
assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax) assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
...@@ -351,6 +328,9 @@ class TestLogSoftmax(utt.InferShapeTester): ...@@ -351,6 +328,9 @@ class TestLogSoftmax(utt.InferShapeTester):
assert softmax_grad_legacy in [n.op for n in fgraph.toposort()] assert softmax_grad_legacy in [n.op for n in fgraph.toposort()]
def test_valid_axis(self):
valid_axis_tester(LogSoftmax)
class TestSoftmaxGrad(utt.InferShapeTester): class TestSoftmaxGrad(utt.InferShapeTester):
def test_infer_shape(self): def test_infer_shape(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论