提交 abedb7fb authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Start using new API in tests that don't involve shared updates

上级 78293400
...@@ -23,10 +23,11 @@ jax = pytest.importorskip("jax") ...@@ -23,10 +23,11 @@ jax = pytest.importorskip("jax")
@pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)]) @pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)])
def test_scan_sit_sot(view): def test_scan_sit_sot(view):
x0 = pt.scalar("x0", dtype="float64") x0 = pt.scalar("x0", dtype="float64")
xs, _ = scan( xs = scan(
lambda xtm1: xtm1 + 1, lambda xtm1: xtm1 + 1,
outputs_info=[x0], outputs_info=[x0],
n_steps=10, n_steps=10,
return_updates=False,
) )
if view: if view:
xs = xs[view] xs = xs[view]
...@@ -37,10 +38,11 @@ def test_scan_sit_sot(view): ...@@ -37,10 +38,11 @@ def test_scan_sit_sot(view):
@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)]) @pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
def test_scan_mit_sot(view): def test_scan_mit_sot(view):
x0 = pt.vector("x0", dtype="float64", shape=(3,)) x0 = pt.vector("x0", dtype="float64", shape=(3,))
xs, _ = scan( xs = scan(
lambda xtm3, xtm1: xtm3 + xtm1 + 1, lambda xtm3, xtm1: xtm3 + xtm1 + 1,
outputs_info=[{"initial": x0, "taps": [-3, -1]}], outputs_info=[{"initial": x0, "taps": [-3, -1]}],
n_steps=10, n_steps=10,
return_updates=False,
) )
if view: if view:
xs = xs[view] xs = xs[view]
...@@ -57,13 +59,14 @@ def test_scan_multiple_mit_sot(view_x, view_y): ...@@ -57,13 +59,14 @@ def test_scan_multiple_mit_sot(view_x, view_y):
def step(xtm3, xtm1, ytm4, ytm2): def step(xtm3, xtm1, ytm4, ytm2):
return xtm3 + ytm4 + 1, xtm1 + ytm2 + 2 return xtm3 + ytm4 + 1, xtm1 + ytm2 + 2
[xs, ys], _ = scan( [xs, ys] = scan(
fn=step, fn=step,
outputs_info=[ outputs_info=[
{"initial": x0, "taps": [-3, -1]}, {"initial": x0, "taps": [-3, -1]},
{"initial": y0, "taps": [-4, -2]}, {"initial": y0, "taps": [-4, -2]},
], ],
n_steps=10, n_steps=10,
return_updates=False,
) )
if view_x: if view_x:
xs = xs[view_x] xs = xs[view_x]
...@@ -80,10 +83,8 @@ def test_scan_nit_sot(view): ...@@ -80,10 +83,8 @@ def test_scan_nit_sot(view):
xs = pt.vector("x0", dtype="float64", shape=(10,)) xs = pt.vector("x0", dtype="float64", shape=(10,))
ys, _ = scan( ys = scan(
lambda x: pt.exp(x), lambda x: pt.exp(x), outputs_info=[None], sequences=[xs], return_updates=False
outputs_info=[None],
sequences=[xs],
) )
if view: if view:
ys = ys[view] ys = ys[view]
...@@ -106,11 +107,12 @@ def test_scan_mit_mot(): ...@@ -106,11 +107,12 @@ def test_scan_mit_mot():
rho = pt.scalar("rho", dtype="float64") rho = pt.scalar("rho", dtype="float64")
x0 = pt.vector("xs", shape=(2,)) x0 = pt.vector("xs", shape=(2,))
y0 = pt.vector("ys", shape=(3,)) y0 = pt.vector("ys", shape=(3,))
[outs, _], _ = scan( [outs, _] = scan(
step, step,
outputs_info=[x0, {"initial": y0, "taps": [-3, -1]}], outputs_info=[x0, {"initial": y0, "taps": [-3, -1]}],
non_sequences=[rho], non_sequences=[rho],
n_steps=10, n_steps=10,
return_updates=False,
) )
grads = pt.grad(outs.sum(), wrt=[x0, y0, rho]) grads = pt.grad(outs.sum(), wrt=[x0, y0, rho])
compare_jax_and_py( compare_jax_and_py(
...@@ -191,10 +193,11 @@ def test_scan_rng_update(): ...@@ -191,10 +193,11 @@ def test_scan_rng_update():
@pytest.mark.xfail(raises=NotImplementedError) @pytest.mark.xfail(raises=NotImplementedError)
def test_scan_while(): def test_scan_while():
xs, _ = scan( xs = scan(
lambda x: (x + 1, until(x < 10)), lambda x: (x + 1, until(x < 10)),
outputs_info=[pt.zeros(())], outputs_info=[pt.zeros(())],
n_steps=100, n_steps=100,
return_updates=False,
) )
compare_jax_and_py([], [xs], []) compare_jax_and_py([], [xs], [])
...@@ -210,7 +213,7 @@ def test_scan_mitsot_with_nonseq(): ...@@ -210,7 +213,7 @@ def test_scan_mitsot_with_nonseq():
res.name = "y_t" res.name = "y_t"
return res return res
y_scan_pt, _ = scan( y_scan_pt = scan(
fn=input_step_fn, fn=input_step_fn,
outputs_info=[ outputs_info=[
{ {
...@@ -223,6 +226,7 @@ def test_scan_mitsot_with_nonseq(): ...@@ -223,6 +226,7 @@ def test_scan_mitsot_with_nonseq():
non_sequences=[a_pt], non_sequences=[a_pt],
n_steps=10, n_steps=10,
name="y_scan", name="y_scan",
return_updates=False,
) )
y_scan_pt.name = "y" y_scan_pt.name = "y"
y_scan_pt.owner.inputs[0].name = "y_all" y_scan_pt.owner.inputs[0].name = "y_all"
...@@ -241,11 +245,12 @@ def test_nd_scan_sit_sot(x0_func, A_func): ...@@ -241,11 +245,12 @@ def test_nd_scan_sit_sot(x0_func, A_func):
k = 3 k = 3
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan( xs = scan(
lambda X, A: A @ X, lambda X, A: A @ X,
non_sequences=[A], non_sequences=[A],
outputs_info=[x0], outputs_info=[x0],
n_steps=n_steps, n_steps=n_steps,
return_updates=False,
) )
x0_val = ( x0_val = (
...@@ -267,11 +272,12 @@ def test_nd_scan_sit_sot_with_seq(): ...@@ -267,11 +272,12 @@ def test_nd_scan_sit_sot_with_seq():
A = pt.matrix("A", shape=(k, k)) A = pt.matrix("A", shape=(k, k))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan( xs = scan(
lambda X, A: A @ X, lambda X, A: A @ X,
non_sequences=[A], non_sequences=[A],
sequences=[x], sequences=[x],
n_steps=n_steps, n_steps=n_steps,
return_updates=False,
) )
x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k) x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k)
...@@ -287,11 +293,12 @@ def test_nd_scan_mit_sot(): ...@@ -287,11 +293,12 @@ def test_nd_scan_mit_sot():
B = pt.matrix("B", shape=(3, 3)) B = pt.matrix("B", shape=(3, 3))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan( xs = scan(
lambda xtm3, xtm1, A, B: A @ xtm3 + B @ xtm1, lambda xtm3, xtm1, A, B: A @ xtm3 + B @ xtm1,
outputs_info=[{"initial": x0, "taps": [-3, -1]}], outputs_info=[{"initial": x0, "taps": [-3, -1]}],
non_sequences=[A, B], non_sequences=[A, B],
n_steps=10, n_steps=10,
return_updates=False,
) )
x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3) x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3)
...@@ -310,12 +317,13 @@ def test_nd_scan_sit_sot_with_carry(): ...@@ -310,12 +317,13 @@ def test_nd_scan_sit_sot_with_carry():
return A @ x, x.sum() return A @ x, x.sum()
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan( xs = scan(
step, step,
outputs_info=[x0, None], outputs_info=[x0, None],
non_sequences=[A], non_sequences=[A],
n_steps=10, n_steps=10,
mode=get_mode("JAX"), mode=get_mode("JAX"),
return_updates=False,
) )
x0_val = np.arange(3, dtype=config.floatX) x0_val = np.arange(3, dtype=config.floatX)
...@@ -329,7 +337,13 @@ def test_default_mode_excludes_incompatible_rewrites(): ...@@ -329,7 +337,13 @@ def test_default_mode_excludes_incompatible_rewrites():
# See issue #426 # See issue #426
A = matrix("A") A = matrix("A")
B = matrix("B") B = matrix("B")
out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2) out = scan(
lambda a, b: a @ b,
outputs_info=[A],
non_sequences=[B],
n_steps=2,
return_updates=False,
)
compare_jax_and_py([A, B], [out], [np.eye(3), np.eye(3)], jax_mode="JAX") compare_jax_and_py([A, B], [out], [np.eye(3), np.eye(3)], jax_mode="JAX")
...@@ -353,8 +367,11 @@ def test_dynamic_sequence_length(): ...@@ -353,8 +367,11 @@ def test_dynamic_sequence_length():
x = pt.tensor("x", shape=(None, 3)) x = pt.tensor("x", shape=(None, 3))
out, _ = scan( out = scan(
lambda x: inc_without_static_shape(x), outputs_info=[None], sequences=[x] lambda x: inc_without_static_shape(x),
outputs_info=[None],
sequences=[x],
return_updates=False,
) )
f = function([x], out, mode=get_mode("JAX").excluding("scan")) f = function([x], out, mode=get_mode("JAX").excluding("scan"))
assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1 assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1
...@@ -364,10 +381,11 @@ def test_dynamic_sequence_length(): ...@@ -364,10 +381,11 @@ def test_dynamic_sequence_length():
np.testing.assert_allclose(f(np.zeros((0, 3))), np.empty((0, 3))) np.testing.assert_allclose(f(np.zeros((0, 3))), np.empty((0, 3)))
# With known static shape we should always manage, regardless of the internal implementation # With known static shape we should always manage, regardless of the internal implementation
out2, _ = scan( out2 = scan(
lambda x: pt.specify_shape(inc_without_static_shape(x), x.shape), lambda x: pt.specify_shape(inc_without_static_shape(x), x.shape),
outputs_info=[None], outputs_info=[None],
sequences=[x], sequences=[x],
return_updates=False,
) )
f2 = function([x], out2, mode=get_mode("JAX").excluding("scan")) f2 = function([x], out2, mode=get_mode("JAX").excluding("scan"))
np.testing.assert_allclose(f2([[1, 2, 3]]), np.array([[2, 3, 4]])) np.testing.assert_allclose(f2([[1, 2, 3]]), np.array([[2, 3, 4]]))
...@@ -418,11 +436,12 @@ def SEIR_model_logp(): ...@@ -418,11 +436,12 @@ def SEIR_model_logp():
it1 = it0 + ct0 - dt0 it1 = it0 + ct0 - dt0
return st1, et1, it1, logp_c1, logp_d1 return st1, et1, it1, logp_c1, logp_d1
(st, et, it, logp_c_all, logp_d_all), _ = scan( (st, et, it, logp_c_all, logp_d_all) = scan(
fn=seir_one_step, fn=seir_one_step,
sequences=[C_t, D_t], sequences=[C_t, D_t],
outputs_info=[st0, et0, it0, None, None], outputs_info=[st0, et0, it0, None, None],
non_sequences=[beta, gamma, delta], non_sequences=[beta, gamma, delta],
return_updates=False,
) )
st.name = "S_t" st.name = "S_t"
et.name = "E_t" et.name = "E_t"
...@@ -511,11 +530,12 @@ def cyclical_reduction(): ...@@ -511,11 +530,12 @@ def cyclical_reduction():
max_iter = 100 max_iter = 100
tol = 1e-7 tol = 1e-7
(*_, A1_hat, norm, _n_steps), _ = scan( (*_, A1_hat, norm, _n_steps) = scan(
step, step,
outputs_info=[A, B, C, B, norm, step_num], outputs_info=[A, B, C, B, norm, step_num],
non_sequences=[tol], non_sequences=[tol],
n_steps=max_iter, n_steps=max_iter,
return_updates=False,
) )
A1_hat = A1_hat[-1] A1_hat = A1_hat[-1]
......
...@@ -206,11 +206,12 @@ def test_scan_multiple_output(benchmark): ...@@ -206,11 +206,12 @@ def test_scan_multiple_output(benchmark):
it1 = it0 + ct0 - dt0 it1 = it0 + ct0 - dt0
return st1, et1, it1, logp_c1, logp_d1 return st1, et1, it1, logp_c1, logp_d1
(st, et, it, logp_c_all, logp_d_all), _ = scan( (st, et, it, logp_c_all, logp_d_all) = scan(
fn=seir_one_step, fn=seir_one_step,
sequences=[pt_C, pt_D], sequences=[pt_C, pt_D],
outputs_info=[st0, et0, it0, logp_c, logp_d], outputs_info=[st0, et0, it0, logp_c, logp_d],
non_sequences=[beta, gamma, delta], non_sequences=[beta, gamma, delta],
return_updates=False,
) )
st.name = "S_t" st.name = "S_t"
et.name = "E_t" et.name = "E_t"
...@@ -268,7 +269,7 @@ def test_scan_tap_output(): ...@@ -268,7 +269,7 @@ def test_scan_tap_output():
y_t.name = "y_t" y_t.name = "y_t"
return x_t, y_t, pt.fill((10,), z_t) return x_t, y_t, pt.fill((10,), z_t)
scan_res, _ = scan( scan_res = scan(
fn=input_step_fn, fn=input_step_fn,
sequences=[ sequences=[
{ {
...@@ -297,6 +298,7 @@ def test_scan_tap_output(): ...@@ -297,6 +298,7 @@ def test_scan_tap_output():
n_steps=5, n_steps=5,
name="yz_scan", name="yz_scan",
strict=True, strict=True,
return_updates=False,
) )
test_input_vals = [ test_input_vals = [
...@@ -312,11 +314,12 @@ def test_scan_while(): ...@@ -312,11 +314,12 @@ def test_scan_while():
return previous_power * 2, until(previous_power * 2 > max_value) return previous_power * 2, until(previous_power * 2 > max_value)
max_value = pt.scalar() max_value = pt.scalar()
values, _ = scan( values = scan(
power_of_2, power_of_2,
outputs_info=pt.constant(1.0), outputs_info=pt.constant(1.0),
non_sequences=max_value, non_sequences=max_value,
n_steps=1024, n_steps=1024,
return_updates=False,
) )
test_input_vals = [ test_input_vals = [
...@@ -331,11 +334,12 @@ def test_scan_multiple_none_output(): ...@@ -331,11 +334,12 @@ def test_scan_multiple_none_output():
def power_step(prior_result, x): def power_step(prior_result, x):
return prior_result * x, prior_result * x * x, prior_result * x * x * x return prior_result * x, prior_result * x * x, prior_result * x * x * x
result, _ = scan( result = scan(
power_step, power_step,
non_sequences=[A], non_sequences=[A],
outputs_info=[pt.ones_like(A), None, None], outputs_info=[pt.ones_like(A), None, None],
n_steps=3, n_steps=3,
return_updates=False,
) )
test_input_vals = (np.array([1.0, 2.0]),) test_input_vals = (np.array([1.0, 2.0]),)
compare_numba_and_py([A], result, test_input_vals) compare_numba_and_py([A], result, test_input_vals)
...@@ -343,8 +347,12 @@ def test_scan_multiple_none_output(): ...@@ -343,8 +347,12 @@ def test_scan_multiple_none_output():
def test_grad_sitsot(): def test_grad_sitsot():
def get_sum_of_grad(inp): def get_sum_of_grad(inp):
scan_outputs, _updates = scan( scan_outputs = scan(
fn=lambda x: x * 2, outputs_info=[inp], n_steps=5, mode="NUMBA" fn=lambda x: x * 2,
outputs_info=[inp],
n_steps=5,
mode="NUMBA",
return_updates=False,
) )
return grad(scan_outputs.sum(), inp).sum() return grad(scan_outputs.sum(), inp).sum()
...@@ -362,8 +370,11 @@ def test_mitmots_basic(): ...@@ -362,8 +370,11 @@ def test_mitmots_basic():
def inner_fct(seq, state_old, state_current): def inner_fct(seq, state_old, state_current):
return state_old * 2 + state_current + seq return state_old * 2 + state_current + seq
out, _ = scan( out = scan(
inner_fct, sequences=seq, outputs_info={"initial": init_x, "taps": [-2, -1]} inner_fct,
sequences=seq,
outputs_info={"initial": init_x, "taps": [-2, -1]},
return_updates=False,
) )
g_outs = grad(out.sum(), [seq, init_x]) g_outs = grad(out.sum(), [seq, init_x])
...@@ -383,10 +394,11 @@ def test_mitmots_basic(): ...@@ -383,10 +394,11 @@ def test_mitmots_basic():
def test_inner_graph_optimized(): def test_inner_graph_optimized():
"""Test that inner graph of Scan is optimized""" """Test that inner graph of Scan is optimized"""
xs = vector("xs") xs = vector("xs")
seq, _ = scan( seq = scan(
fn=lambda x: log(1 + x), fn=lambda x: log(1 + x),
sequences=[xs], sequences=[xs],
mode=get_mode("NUMBA"), mode=get_mode("NUMBA"),
return_updates=False,
) )
# Disable scan pushout, in which case the whole scan is replaced by an Elemwise # Disable scan pushout, in which case the whole scan is replaced by an Elemwise
...@@ -421,13 +433,14 @@ def test_vector_taps_benchmark(benchmark): ...@@ -421,13 +433,14 @@ def test_vector_taps_benchmark(benchmark):
sitsot2 = (sitsot1 + mitsot3) / np.sqrt(2) sitsot2 = (sitsot1 + mitsot3) / np.sqrt(2)
return mitsot3, sitsot2 return mitsot3, sitsot2
outs, _ = scan( outs = scan(
fn=step, fn=step,
sequences=[seq1, seq2], sequences=[seq1, seq2],
outputs_info=[ outputs_info=[
dict(initial=mitsot_init, taps=[-2, -1]), dict(initial=mitsot_init, taps=[-2, -1]),
dict(initial=sitsot_init, taps=[-1]), dict(initial=sitsot_init, taps=[-1]),
], ],
return_updates=False,
) )
rng = np.random.default_rng(474) rng = np.random.default_rng(474)
...@@ -468,7 +481,7 @@ def test_inplace_taps(n_steps_constant): ...@@ -468,7 +481,7 @@ def test_inplace_taps(n_steps_constant):
y = ytm1 + 1 + ytm2 + a y = ytm1 + 1 + ytm2 + a
return z, x, z + x + y, y return z, x, z + x + y, y
[zs, xs, ws, ys], _ = scan( [zs, xs, ws, ys] = scan(
fn=step, fn=step,
outputs_info=[ outputs_info=[
dict(initial=z0, taps=[-3, -1]), dict(initial=z0, taps=[-3, -1]),
...@@ -478,6 +491,7 @@ def test_inplace_taps(n_steps_constant): ...@@ -478,6 +491,7 @@ def test_inplace_taps(n_steps_constant):
], ],
non_sequences=[a], non_sequences=[a],
n_steps=n_steps, n_steps=n_steps,
return_updates=False,
) )
numba_fn, _ = compare_numba_and_py( numba_fn, _ = compare_numba_and_py(
[n_steps] * (not n_steps_constant) + [a, x0, y0, z0], [n_steps] * (not n_steps_constant) + [a, x0, y0, z0],
...@@ -529,10 +543,11 @@ def test_inplace_taps(n_steps_constant): ...@@ -529,10 +543,11 @@ def test_inplace_taps(n_steps_constant):
class TestScanSITSOTBuffer: class TestScanSITSOTBuffer:
def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None): def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None):
x0 = pt.vector(shape=(op_size,), dtype="float64") x0 = pt.vector(shape=(op_size,), dtype="float64")
xs, _ = pytensor.scan( xs = pytensor.scan(
fn=lambda xtm1: (xtm1 + 1), fn=lambda xtm1: (xtm1 + 1),
outputs_info=[x0], outputs_info=[x0],
n_steps=n_steps - 1, # 1- makes it easier to align/misalign n_steps=n_steps - 1, # 1- makes it easier to align/misalign
return_updates=False,
) )
if buffer_size == "unit": if buffer_size == "unit":
xs_kept = xs[-1] # Only last state is used xs_kept = xs[-1] # Only last state is used
...@@ -588,12 +603,13 @@ class TestScanMITSOTBuffer: ...@@ -588,12 +603,13 @@ class TestScanMITSOTBuffer:
init_x = pt.vector("init_x", shape=(2,)) init_x = pt.vector("init_x", shape=(2,))
n_steps = pt.iscalar("n_steps") n_steps = pt.iscalar("n_steps")
output, _ = scan( output = scan(
f_pow2, f_pow2,
sequences=[], sequences=[],
outputs_info=[{"initial": init_x, "taps": [-2, -1]}], outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
non_sequences=[], non_sequences=[],
n_steps=n_steps_val if constant_n_steps else n_steps, n_steps=n_steps_val if constant_n_steps else n_steps,
return_updates=False,
) )
init_x_val = np.array([1.0, 2.0], dtype=init_x.type.dtype) init_x_val = np.array([1.0, 2.0], dtype=init_x.type.dtype)
......
差异被折叠。
...@@ -170,11 +170,12 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed): ...@@ -170,11 +170,12 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
A = tensor("A", shape=(3, 3)) A = tensor("A", shape=(3, 3))
x0 = tensor("b", shape=(3, 4)) x0 = tensor("b", shape=(3, 4))
xs, _ = scan( xs = scan(
lambda xtm1, A: solve(A, xtm1, assume_a=assume_a, transposed=transposed), lambda xtm1, A: solve(A, xtm1, assume_a=assume_a, transposed=transposed),
outputs_info=[x0], outputs_info=[x0],
non_sequences=[A], non_sequences=[A],
n_steps=10, n_steps=10,
return_updates=False,
) )
fn_no_opt = function( fn_no_opt = function(
......
...@@ -694,10 +694,11 @@ def test_blockwise_grad_core_type(): ...@@ -694,10 +694,11 @@ def test_blockwise_grad_core_type():
def test_scan_gradient_core_type(): def test_scan_gradient_core_type():
n_steps = 3 n_steps = 3
seq = tensor("seq", shape=(n_steps, 1), dtype="float64") seq = tensor("seq", shape=(n_steps, 1), dtype="float64")
out, _ = scan( out = scan(
lambda s: s, lambda s: s,
sequences=[seq], sequences=[seq],
n_steps=n_steps, n_steps=n_steps,
return_updates=False,
) )
vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64") vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论