提交 4aea87c2 作者: Ricardo Vieira 提交者: Ricardo Vieira

Fix bug when taking the L_op of a Scan with mit-mot and disconnected output gradients

上级 49cf9d22
...@@ -2509,13 +2509,25 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2509,13 +2509,25 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
return rval return rval
var_mappings = self.get_oinp_iinp_iout_oout_mappings() var_mappings = self.get_oinp_iinp_iout_oout_mappings()
dC_dinps_t = [None for inp in diff_inputs]
disconnected_dC_dinps_t = [True for inp in diff_inputs] disconnected_dC_dinps_t = [True for inp in diff_inputs]
n_mit_mot_outs = info.n_mit_mot_outs
# In the case of mit-mot there can be more inner outputs than outer ones
n_extra_mit_mot_outs = n_mit_mot_outs - info.n_mit_mot
idx_nitsot_out_start = n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot
idx_nitsot_out_end = idx_nitsot_out_start + info.n_nit_sot
# Create dummy variables for the internal input gradients
states = (
self.inner_mitmot(self_inputs)
+ self.inner_mitsot(self_inputs)
+ self.inner_sitsot(self_inputs)
)
dC_dXts = [] dC_dXts = []
Xts = [] Xts = []
for idx, Xt in enumerate(diff_outputs): for idx, Xt in enumerate(diff_outputs):
# We are looking for x[t-1] for a given x[t] # We are looking for x[t-1] for a given x[t]
if idx >= info.n_mit_mot_outs: if idx >= n_mit_mot_outs:
Xt_placeholder = safe_new(Xt) Xt_placeholder = safe_new(Xt)
Xts.append(Xt_placeholder) Xts.append(Xt_placeholder)
...@@ -2523,9 +2535,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2523,9 +2535,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# or not. NOTE : This cannot be done by using # or not. NOTE : This cannot be done by using
# "if Xt not in self.inner_nitsot_outs(self_outputs)" because # "if Xt not in self.inner_nitsot_outs(self_outputs)" because
# the exact same variable can be used as multiple outputs. # the exact same variable can be used as multiple outputs.
idx_nitsot_start = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot if idx < idx_nitsot_out_start or idx >= idx_nitsot_out_end:
idx_nitsot_end = idx_nitsot_start + info.n_nit_sot
if idx < idx_nitsot_start or idx >= idx_nitsot_end:
# What we do here is loop through dC_douts and collect all # What we do here is loop through dC_douts and collect all
# those that are connected to the specific one and do an # those that are connected to the specific one and do an
# upcast on all of their dtypes to get the dtype for this # upcast on all of their dtypes to get the dtype for this
...@@ -2533,12 +2543,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2533,12 +2543,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# specific previous step is defined or not is done somewhere # specific previous step is defined or not is done somewhere
# else. # else.
dtypes = [] dtypes = []
states = (
self.inner_mitmot(self_inputs)
+ self.inner_mitsot(self_inputs)
+ self.inner_sitsot(self_inputs)
)
for pos, inp in enumerate(states): for pos, inp in enumerate(states):
if inp in graph_inputs([Xt]): if inp in graph_inputs([Xt]):
# Get the index of the outer output that to which # Get the index of the outer output that to which
...@@ -2555,35 +2559,39 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2555,35 +2559,39 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
new_dtype = config.floatX new_dtype = config.floatX
dC_dXt = safe_new(Xt, dtype=new_dtype) dC_dXt = safe_new(Xt, dtype=new_dtype)
else: else:
if isinstance(dC_douts[idx].type, DisconnectedType): # nit-sot outputs
# If not disconnected assume the output gradient type is a valid type for the input gradient
if isinstance(
dC_douts[idx - n_extra_mit_mot_outs].type, DisconnectedType
):
continue continue
dC_dXt = safe_new(dC_douts[idx][0]) dC_dXt = safe_new(dC_douts[idx - n_extra_mit_mot_outs][0])
dC_dXts.append(dC_dXt) dC_dXts.append(dC_dXt)
# Handle cases where the very same variable may be used as different outputs
# TODO: Couldn't we add a view Op to avoid this when building the Scan graph?
known_grads = {} known_grads = {}
dc_dxts_idx = 0 dc_dxts_idx = 0
for i in range(len(diff_outputs)): for i in range(len(diff_outputs)):
if i < idx_nitsot_start or i >= idx_nitsot_end: if not (i < idx_nitsot_out_start or i >= idx_nitsot_out_end) and isinstance(
if diff_outputs[i] in known_grads: dC_douts[i - n_extra_mit_mot_outs].type, DisconnectedType
known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx] ):
else: # Special case where we don't have a dC_dXt for disconnected nitsot outputs
known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx] continue
dc_dxts_idx += 1
# Just some trouble to avoid a +0
if diff_outputs[i] in known_grads:
known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
else: else:
if isinstance(dC_douts[i].type, DisconnectedType): known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
continue dc_dxts_idx += 1
else:
if diff_outputs[i] in known_grads:
known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
else:
known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
dc_dxts_idx += 1
dC_dinps_t = compute_all_gradients(known_grads) dC_dinps_t = compute_all_gradients(known_grads)
# mask inputs that get no gradients # mask inputs that get no gradients
for dx in range(len(dC_dinps_t)): for dx in range(len(dC_dinps_t)):
if not dC_dinps_t[dx]: if dC_dinps_t[dx] is None:
dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx]) dC_dinps_t[dx] = dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx])
else: else:
disconnected_dC_dinps_t[dx] = False disconnected_dC_dinps_t[dx] = False
for Xt, Xt_placeholder in zip( for Xt, Xt_placeholder in zip(
...@@ -2846,7 +2854,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2846,7 +2854,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for idx in range(info.n_sit_sot): for idx in range(info.n_sit_sot):
mitmot_inp_taps.append([0, 1]) mitmot_inp_taps.append([0, 1])
mitmot_out_taps.append([1]) mitmot_out_taps.append([1])
through_shared = False
if not isinstance(dC_douts[idx + offset].type, DisconnectedType): if not isinstance(dC_douts[idx + offset].type, DisconnectedType):
outer_inp_mitmot.append(dC_douts[idx + offset][::-1]) outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
else: else:
...@@ -3007,9 +3014,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3007,9 +3014,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
name=f"grad_of_{self.name}" if self.name else None, name=f"grad_of_{self.name}" if self.name else None,
allow_gc=self.allow_gc, allow_gc=self.allow_gc,
) )
outputs = local_op(*outer_inputs) outputs = local_op(*outer_inputs, return_list=True)
if not isinstance(outputs, list | tuple):
outputs = [outputs]
# Re-order the gradients correctly # Re-order the gradients correctly
gradients = [DisconnectedType()()] gradients = [DisconnectedType()()]
...@@ -3095,7 +3100,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3095,7 +3100,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
) )
) )
start = len(gradients)
gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)] gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)]
begin = end begin = end
......
...@@ -2128,6 +2128,57 @@ class TestScan: ...@@ -2128,6 +2128,57 @@ class TestScan:
# TODO: We should test something about the Rop! # TODO: We should test something about the Rop!
Rop(d_cost_wrt_pars, pars, p) Rop(d_cost_wrt_pars, pars, p)
def test_second_derivative_disconnected_cost_with_mit_mot(self):
    """Regression test for second derivatives taken through a Scan gradient.

    The bug was revealed when computing the pushforward of a Scan gradient
    via two applications of the pullback. The same recurrence is
    differentiated twice while varying which first-order gradients are
    requested alongside the one with respect to ``x0``.
    """
    seq = pt.vector("seq", shape=(2,))
    z = pt.scalar("z")
    x0 = pt.vector("x0", shape=(2,))

    def step(s, xtm2, xtm1, z):
        # x[t] = s * x[t-1] ** 2 * (z / 2); the t-2 tap enters with weight 0.
        # With s == 1 and z == 2, two steps give x ** 4, whose first
        # derivative is 4 * x ** 3 and second derivative is 12 * x ** 2.
        return s * ((xtm2 * 0 + xtm1) ** 2) * (z / 2)

    xs, _ = scan(
        step,
        sequences=[seq],
        outputs_info=[{"initial": x0, "taps": (-2, -1)}],
        non_sequences=[z],
        n_steps=2,
    )
    last_x = xs[-1]

    def evaluate(expr, seq_val, z_val):
        # Every check uses x0 == [1, 1]; only seq and z vary between cases.
        return expr.eval({seq: seq_val, x0: [1, 1], z: z_val})

    # Gradients w.r.t. x0, z and seq all requested; z and seq terms are
    # kept in the cost but multiplied by zero.
    d_x0, d_z, d_seq = pt.grad(last_x, [x0, z, seq])
    first_order = d_x0.sum() + d_z.sum() * 0 + d_seq.sum() * 0
    assert evaluate(first_order, [1, 1], 2) == 4
    second_order = pt.grad(first_order, wrt=x0).sum()
    assert evaluate(second_order, [1, 1], 2) == 12
    assert evaluate(second_order, [2, 2], 2) == 96

    # Leave out z
    d_x0, d_seq = pt.grad(last_x, [x0, seq])
    first_order = d_x0.sum() + d_seq.sum() * 0
    second_order = pt.grad(first_order, wrt=x0).sum()
    assert evaluate(second_order, [1, 1], 2) == 12
    assert evaluate(second_order, [2, 2], 2) == 96

    # Leave out seq
    d_x0, d_z = pt.grad(last_x, [x0, z])
    first_order = d_x0.sum() + d_z.sum() * 0
    second_order = pt.grad(first_order, wrt=x0).sum()
    assert evaluate(second_order, [1, 1], 2) == 12
    assert evaluate(second_order, [1, 1], 1) == 3 / 2

    # Leave out z and seq
    d_x0 = pt.grad(last_x, x0)
    first_order = d_x0.sum()
    second_order = pt.grad(first_order, wrt=x0).sum()
    assert evaluate(second_order, [1, 1], 2) == 12
    assert evaluate(second_order, [1, 1], 1) == 3 / 2
@pytest.mark.skipif( @pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test." not config.cxx, reason="G++ not available, so we need to skip this test."
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论