提交 0b97f658 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Stop inner-graph output reordering in Scan

In order to use shared updates to pre-allocate the storage for mit-mot input and output loops, `Scan` would need to remove the corresponding mit-mot outputs from its inner-`FunctionGraph` before compilation and it would expect the `Function` compilation pipeline to add them back at the end of the remaining outputs. Now, `Scan`'s inner-`FunctionGraph`s maintain the same form at every point, and no special logic is needed to compensate for post-compilation changes in the order/location of inputs and outputs.
上级 27d2bfe3
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -861,30 +861,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -861,30 +861,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
self.n_tap_outs = info.n_mit_mot + info.n_mit_sot self.n_tap_outs = info.n_mit_mot + info.n_mit_sot
# TODO: These can be moved to thunk/function compilation
( (
self.preallocated_mitmot_outs, _,
self.mitmots_preallocated, self.mitmots_preallocated,
) = self._mitmot_preallocations() ) = self._mitmot_preallocations()
self.n_outer_inputs = info.n_outer_inputs self.n_outer_inputs = info.n_outer_inputs
self.n_outer_outputs = info.n_outer_outputs self.n_outer_outputs = info.n_outer_outputs
features = [] self.fgraph = FunctionGraph(inputs, outputs, clone=False)
if config.scan__allow_output_prealloc: _ = self.prepare_fgraph(self.fgraph)
# This feature will prevent mitsot, sitsot and nitsot outputs from
# being computed inplace (to allow their preallocation).
mitsot_start = info.n_mit_mot_outs - len(self.preallocated_mitmot_outs)
nitsot_end = mitsot_start + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot
features.append(NoOutputFromInplace(range(mitsot_start, nitsot_end)))
self.fgraph = FunctionGraph(
inputs,
outputs,
clone=False,
features=features,
)
if any(node.op.destroy_map for node in self.fgraph.apply_nodes): if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
raise InconsistencyError( raise InconsistencyError(
...@@ -1360,23 +1348,21 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1360,23 +1348,21 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
) )
) )
@property def prepare_fgraph(self, fgraph):
def fn(self): """Update and wrap `fgraph`'s inputs and outputs in preparation for compilation."""
"""Lazily compile the inner function graph."""
if getattr(self, "_fn", None) is not None:
return self._fn
# If a shared variable is the result of a ViewOp it is a clear info = self.info
# indication that we need to copy that value after the perform of slices = info.n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot
# scan is done
slices = (
self.info.n_mit_mot_outs
+ self.info.n_mit_sot
+ self.info.n_sit_sot
+ self.info.n_nit_sot
)
fgraph = self.fgraph.clone() # Setting `fgraph.update_mapping` will indicate to the `Function`
# construction pipeline that it needn't append the updates to the
# `FunctionGraph` outputs itself, because they're already in the given
# `FunctionGraph`'s outputs. This also prevents us from needing to
# remove those outputs here just to compensate for an overly rigid
# `Function` pipeline.
update_mapping = {}
preallocated_mitmot_outs = []
if config.scan__allow_output_prealloc: if config.scan__allow_output_prealloc:
...@@ -1384,32 +1370,31 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1384,32 +1370,31 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# input and an output, wrap the input such that the corresponding # input and an output, wrap the input such that the corresponding
# output variable becomes an update to be performed on it, possibly # output variable becomes an update to be performed on it, possibly
# inplace at the end of the function's execution. # inplace at the end of the function's execution.
wrapped_inputs = [ wrapped_inputs = [In(x, borrow=False) for x in fgraph.inputs[: info.n_seqs]]
In(x, borrow=False) for x in fgraph.inputs[: self.info.n_seqs]
]
new_outputs = [x for x in fgraph.outputs]
input_idx = self.info.n_seqs input_idx = info.n_seqs
for mitmot_idx in range(self.info.n_mit_mot): for mitmot_idx in range(info.n_mit_mot):
for inp_tap in self.info.mit_mot_in_slices[mitmot_idx]: for inp_tap in info.mit_mot_in_slices[mitmot_idx]:
if inp_tap in self.info.mit_mot_out_slices[mitmot_idx]: if inp_tap in info.mit_mot_out_slices[mitmot_idx]:
inp = fgraph.inputs[input_idx] inp = fgraph.inputs[input_idx]
# Figure out the index of the corresponding output # Figure out the index of the corresponding output
output_idx = sum( output_idx = sum(
len(m) for m in self.info.mit_mot_out_slices[:mitmot_idx] len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
) )
output_idx += self.info.mit_mot_out_slices[mitmot_idx].index( output_idx += self.info.mit_mot_out_slices[mitmot_idx].index(
inp_tap inp_tap
) )
# Make it so the input is automatically updated to the preallocated_mitmot_outs.append(output_idx)
# output value, possibly inplace, at the end of the
# function execution. Also, since an update is # Make it so that the input is automatically updated to
# defined, a default value must also be (this is # the output value, possibly inplace, at the end of the
# verified by DebugMode). Use an array of size 0 but # function execution. Also, since an update is defined,
# the right ndim and dtype (use a shape of 1 on # a default value must also be (this is verified by
# broadcastable dimensions, 0 on the others). # DebugMode). Use an array of size 0 with the correct
# ndim and dtype (use a shape of 1 on broadcastable
# dimensions, and 0 on the others).
default_shape = [1 if _b else 0 for _b in inp.broadcastable] default_shape = [1 if _b else 0 for _b in inp.broadcastable]
default_val = inp.type.value_zeros(default_shape) default_val = inp.type.value_zeros(default_shape)
wrapped_inp = In( wrapped_inp = In(
...@@ -1417,10 +1402,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1417,10 +1402,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
value=default_val, value=default_val,
update=fgraph.outputs[output_idx], update=fgraph.outputs[output_idx],
) )
update_mapping[output_idx] = input_idx
wrapped_inputs.append(wrapped_inp) wrapped_inputs.append(wrapped_inp)
else: else:
# Wrap the corresponding input as usual. Leave the
# output as-is.
wrapped_inputs.append( wrapped_inputs.append(
In(fgraph.inputs[input_idx], borrow=False) In(fgraph.inputs[input_idx], borrow=False)
) )
...@@ -1429,22 +1413,57 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1429,22 +1413,57 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Wrap the inputs not associated to mitmots and wrap the remaining # Wrap the inputs not associated to mitmots and wrap the remaining
# outputs # outputs
wrapped_inputs += [In(x, borrow=False) for x in fgraph.inputs[input_idx:]] wrapped_inputs += [In(x, borrow=False) for x in fgraph.inputs[input_idx:]]
wrapped_outputs = [Out(x, borrow=True) for x in new_outputs[:slices]] wrapped_outputs = [Out(x, borrow=True) for x in fgraph.outputs[:slices]]
wrapped_outputs += new_outputs[slices:] wrapped_outputs += fgraph.outputs[slices:]
# Remove now useless outputs from the output list and start from protected_outs = tuple(
# the end to avoid altering the indices of the other outputs to be i
# deleted. for i in range(
for p in self.preallocated_mitmot_outs[::-1]: info.n_mit_mot_outs
fgraph.remove_output(p, reason="scan_prealloc") + info.n_mit_sot
del wrapped_outputs[p] + info.n_sit_sot
+ info.n_nit_sot
)
if i not in preallocated_mitmot_outs
)
fgraph.attach_feature(NoOutputFromInplace(protected_outs))
else: else:
wrapped_inputs = [In(x, borrow=True) for x in fgraph.inputs] wrapped_inputs = [In(x, borrow=True) for x in fgraph.inputs]
wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs[:slices]] wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs[:slices]]
wrapped_outputs += fgraph.outputs[slices:] wrapped_outputs += fgraph.outputs[slices:]
fgraph.update_mapping = update_mapping
from aesara.compile.function.types import Supervisor
from aesara.graph.destroyhandler import DestroyHandler
for node in fgraph.apply_nodes:
if node.op.destroy_map:
fgraph.attach_feature(DestroyHandler())
break
fgraph.attach_feature(
Supervisor(
inp
for spec, inp in zip(wrapped_inputs, fgraph.inputs)
if not (
getattr(spec, "mutable", None)
or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([inp]))
)
)
)
return wrapped_inputs, wrapped_outputs
@property
def fn(self):
"""Lazily compile the inner function graph."""
if getattr(self, "_fn", None) is not None:
return self._fn
wrapped_inputs, wrapped_outputs = self.prepare_fgraph(self.fgraph)
profile = None profile = None
if config.profile or ( if config.profile or (
isinstance(self.profile, (str, bool, (int,))) and self.profile isinstance(self.profile, (str, bool, (int,))) and self.profile
...@@ -1463,7 +1482,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1463,7 +1482,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
accept_inplace=False, accept_inplace=False,
profile=profile, profile=profile,
on_unused_input="ignore", on_unused_input="ignore",
fgraph=fgraph, fgraph=self.fgraph,
) )
return self._fn return self._fn
...@@ -1559,10 +1578,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1559,10 +1578,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_input_storage = [s.storage for s in self.fn.input_storage] inner_input_storage = [s.storage for s in self.fn.input_storage]
inner_output_storage = [s.storage for s in self.fn.output_storage] inner_output_storage = [s.storage for s in self.fn.output_storage]
inner_input_needs_update = tuple(
inp.update is not None for inp in self.fn.maker.expanded_inputs
)
outer_output_dtypes = tuple( outer_output_dtypes = tuple(
getattr(out, "dtype", None) for out in node.outputs getattr(out, "dtype", None) for out in node.outputs
) )
...@@ -1607,8 +1622,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1607,8 +1622,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_outs_is_tensor, cython_outs_is_tensor,
inner_input_storage, inner_input_storage,
inner_output_storage, inner_output_storage,
getattr(self.fn.fn, "need_update_inputs", True),
inner_input_needs_update,
cython_destroy_map, cython_destroy_map,
inputs, inputs,
outputs, outputs,
...@@ -1715,10 +1728,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1715,10 +1728,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
seqs.append(seq) seqs.append(seq)
# 2. Allocate memory for the outputs. Construct the list: # 2. Allocate memory for the outputs. Construct the list:
# store_steps -- map containing the length of each output
# pos -- map containing the current position of each
# output
# The length of each output
store_steps = [ store_steps = [
arg.shape[0] arg.shape[0]
for arg in inputs[self.seqs_arg_offset : self.shared_arg_offset] for arg in inputs[self.seqs_arg_offset : self.shared_arg_offset]
...@@ -1763,6 +1774,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1763,6 +1774,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
output_storage[idx][0] = None output_storage[idx][0] = None
return return
# The current position of each output
pos = [ pos = [
(-self.mintaps[idx]) % store_steps[idx] (-self.mintaps[idx]) % store_steps[idx]
for idx in range(self.n_outs + info.n_nit_sot) for idx in range(self.n_outs + info.n_nit_sot)
...@@ -1950,37 +1962,25 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1950,37 +1962,25 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
pdx = offset + info.n_shared_outs pdx = offset + info.n_shared_outs
cond = inner_output_storage[pdx].storage[0] == 0 cond = inner_output_storage[pdx].storage[0] == 0
# 5.2. By calling fn() directly instead of calling the aesara
# function, it is possible that the updates have not been
# performed. Perform the updates if needed.
offset_out = len(inner_output_storage) - 1
if getattr(fn, "need_update_inputs", True):
# Update the inputs that have an update function
for inp, storage in zip(
self.fn.maker.expanded_inputs[::-1], self.fn.input_storage[::-1]
):
if inp.update is not None:
storage.data = inner_output_storage[offset_out].data
offset_out -= 1
t_fn += dt_fn t_fn += dt_fn
offset_out = 0 offset_out = 0
# 5.3 Copy over the values for mit_mot outputs # 5.3 Copy over the values for mit_mot outputs
mitmot_inp_offset = 0 mitmot_inp_grp_offset = 0
mitmot_out_idx = 0 mitmot_out_idx = 0
for j, taps in enumerate(info.mit_mot_in_slices): for mitmot_grp_idx, taps in enumerate(info.mit_mot_in_slices):
for k in info.mit_mot_out_slices[j]: for out_slice in info.mit_mot_out_slices[mitmot_grp_idx]:
if self.mitmots_preallocated[mitmot_out_idx]: if self.mitmots_preallocated[mitmot_out_idx]:
# This output tap has been preallocated.
inp_idx = mitmot_inp_offset + taps.index(k) mitmot_inp_idx = mitmot_inp_grp_offset + taps.index(out_slice)
inner_inp_idx = self.n_seqs + mitmot_inp_idx
# Verify whether the input points to the same data as # Verify whether the input points to the same data as
# it did before the execution of the inner function. # it did before the execution of the inner function.
old_var = old_mitmot_input_storage[inp_idx] old_var = old_mitmot_input_storage[mitmot_inp_idx]
new_var = inner_input_storage[info.n_seqs + inp_idx].storage[0] new_var = inner_input_storage[inner_inp_idx].storage[0]
if old_var is new_var: if old_var is new_var:
old_data = old_mitmot_input_data[inp_idx] old_data = old_mitmot_input_data[mitmot_inp_idx]
same_data = new_var.data == old_data same_data = new_var.data == old_data
else: else:
same_data = False same_data = False
...@@ -1990,21 +1990,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1990,21 +1990,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# nothing needs to be done. Otherwise, recover the # nothing needs to be done. Otherwise, recover the
# value and store it in `outs` as usual # value and store it in `outs` as usual
if not same_data: if not same_data:
output_storage[j][0][k + pos[j]] = inner_input_storage[ output_storage[mitmot_grp_idx][0][
info.n_seqs + inp_idx out_slice + pos[mitmot_grp_idx]
].storage[0] ] = inner_input_storage[inner_inp_idx].storage[0]
else: else:
# This output tap has not been preallocated, recover # This output tap has not been preallocated, recover
# its value as usual # its value as usual
output_storage[j][0][k + pos[j]] = inner_output_storage[ output_storage[mitmot_grp_idx][0][
offset_out out_slice + pos[mitmot_grp_idx]
].storage[0] ] = inner_output_storage[offset_out].storage[0]
offset_out += 1
offset_out += 1
mitmot_out_idx += 1 mitmot_out_idx += 1
mitmot_inp_offset += len(taps) mitmot_inp_grp_offset += len(taps)
# 5.4 Copy over the values for mit_sot/sit_sot outputs # 5.4 Copy over the values for mit_sot/sit_sot outputs
begin = info.n_mit_mot begin = info.n_mit_mot
......
...@@ -59,7 +59,7 @@ from aesara.scan.utils import InnerFunctionError ...@@ -59,7 +59,7 @@ from aesara.scan.utils import InnerFunctionError
def get_version(): def get_version():
return 0.313 return 0.314
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
...@@ -81,8 +81,6 @@ def perform( ...@@ -81,8 +81,6 @@ def perform(
numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor, numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor,
list inner_input_storage, list inner_input_storage,
list inner_output_storage, list inner_output_storage,
bint need_update_inputs,
tuple inner_input_needs_update,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map, numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
list outer_inputs, list outer_inputs,
list outer_outputs, list outer_outputs,
...@@ -138,10 +136,6 @@ def perform( ...@@ -138,10 +136,6 @@ def perform(
The storage locations for the inner-function's inputs. The storage locations for the inner-function's inputs.
inner_output_storage inner_output_storage
The storage locations for the inner-function's outputs. The storage locations for the inner-function's outputs.
need_update_inputs
A boolean indicating whether or not inner inputs need to be updated.
inner_input_needs_update
A tuple of booleans indicating which inner inputs need to be updated.
fnct: Function fnct: Function
The compiled Aesara inner-function object. The compiled Aesara inner-function object.
destroy_map destroy_map
...@@ -173,8 +167,13 @@ def perform( ...@@ -173,8 +167,13 @@ def perform(
n_shared_outs) n_shared_outs)
cdef unsigned int offset_out cdef unsigned int offset_out
cdef unsigned int lenpos = n_outs + n_nit_sot cdef unsigned int lenpos = n_outs + n_nit_sot
# TODO: See how this is being converted and whether or not we can remove
# fixed allocations caused by this.
cdef int pos[500] # put a maximum of 500 outputs cdef int pos[500] # put a maximum of 500 outputs
cdef unsigned int len_store_steps = n_mit_mot + n_mit_sot + n_sit_sot + n_nit_sot cdef unsigned int len_store_steps = n_mit_mot + n_mit_sot + n_sit_sot + n_nit_sot
# The length of each output
# TODO: See how this is being converted and whether or not we can remove
# fixed allocations caused by this.
cdef int store_steps[500] cdef int store_steps[500]
cdef unsigned int l cdef unsigned int l
cdef unsigned int offset cdef unsigned int offset
...@@ -214,9 +213,8 @@ def perform( ...@@ -214,9 +213,8 @@ def perform(
outer_inputs[1+idx].shape, outer_inputs[1+idx].shape,
n_steps, n_steps,
)) ))
# 2. Allocate memory for the outputs. Construct the list: # 2. Allocate memory for the outputs. Construct the list:
# store_steps -- map containing the length of each output
# pos -- map containing the current position of each output
for idx in range(n_mit_mot + n_mit_sot + n_sit_sot): for idx in range(n_mit_mot + n_mit_sot + n_sit_sot):
store_steps[<unsigned int>idx] = outer_inputs[<unsigned int>(idx+n_seqs+1)].shape[0] store_steps[<unsigned int>idx] = outer_inputs[<unsigned int>(idx+n_seqs+1)].shape[0]
...@@ -400,17 +398,6 @@ def perform( ...@@ -400,17 +398,6 @@ def perform(
pdx = offset + n_shared_outs pdx = offset + n_shared_outs
cond = inner_output_storage[pdx][0] == 0 cond = inner_output_storage[pdx][0] == 0
# 5.2. By calling fn() directly instead of calling the aesara
# function, it is possible that the updates have not been
# performed. Perform the updates if needed.
if need_update_inputs:
offset_out = len(inner_output_storage) - 1
for needs_update, storage in zip(inner_input_needs_update[::-1],
inner_input_storage[::-1]):
if needs_update:
storage[0] = inner_output_storage[offset_out][0]
offset_out -= 1
offset_out = 0 offset_out = 0
# 5.3 Copy over the values for mit_mot outputs # 5.3 Copy over the values for mit_mot outputs
...@@ -421,11 +408,12 @@ def perform( ...@@ -421,11 +408,12 @@ def perform(
if mitmots_preallocated[<unsigned int>mitmot_out_idx]: if mitmots_preallocated[<unsigned int>mitmot_out_idx]:
# This output tap has been preallocated. # This output tap has been preallocated.
inp_idx = (mitmot_inp_offset + tap_array[j].index(k)) inp_idx = (mitmot_inp_offset + tap_array[j].index(k))
inner_inp_idx = n_seqs + inp_idx
# Verify whether the input points to the same data as # Verify whether the input points to the same data as
# it did before the execution of the inner function. # it did before the execution of the inner function.
old_var = old_mitmot_input_storage[inp_idx] old_var = old_mitmot_input_storage[inp_idx]
new_var = inner_input_storage[n_seqs + inp_idx][0] new_var = inner_input_storage[inner_inp_idx][0]
if old_var is new_var: if old_var is new_var:
old_data = old_mitmot_input_data[inp_idx] old_data = old_mitmot_input_data[inp_idx]
same_data = (new_var.data == old_data) same_data = (new_var.data == old_data)
...@@ -437,15 +425,15 @@ def perform( ...@@ -437,15 +425,15 @@ def perform(
# modified inplace and nothing needs to be done. # modified inplace and nothing needs to be done.
if not same_data: if not same_data:
outer_outputs[j][0][<unsigned int>(k + pos[j])] = \ outer_outputs[j][0][<unsigned int>(k + pos[j])] = \
inner_input_storage[<unsigned int>(n_seqs + inp_idx)][0] inner_input_storage[<unsigned int>(inner_inp_idx)][0]
else: else:
# This output tap has not been preallocated, recover # This output tap has not been preallocated, recover
# its value as usual # its value as usual
outer_outputs[j][0][<unsigned int>(k + pos[j])] = \ outer_outputs[j][0][<unsigned int>(k + pos[j])] = \
inner_output_storage[<unsigned int>offset_out][0] inner_output_storage[<unsigned int>offset_out][0]
offset_out += 1
offset_out += 1
mitmot_out_idx += 1 mitmot_out_idx += 1
mitmot_inp_offset += tap_array_len[j] mitmot_inp_offset += tap_array_len[j]
......
...@@ -23,7 +23,7 @@ if not config.cxx: ...@@ -23,7 +23,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform") _logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.313 # must match constant returned in function get_version() version = 0.314 # must match constant returned in function get_version()
need_reload = False need_reload = False
scan_perform: Optional[ModuleType] = None scan_perform: Optional[ModuleType] = None
......
...@@ -843,9 +843,14 @@ def test_random_state_transfer(): ...@@ -843,9 +843,14 @@ def test_random_state_transfer():
np.testing.assert_array_almost_equal(f1(), f2(), decimal=6) np.testing.assert_array_almost_equal(f1(), f2(), decimal=6)
def test_gradient_scan(): @pytest.mark.parametrize(
# Test for a crash when using MRG inside scan and taking the gradient "mode",
# See https://groups.google.com/d/msg/theano-dev/UbcYyU5m-M8/UO9UgXqnQP0J [
"FAST_RUN",
"FAST_COMPILE",
],
)
def test_gradient_scan(mode):
aesara_rng = MRG_RandomStream(10) aesara_rng = MRG_RandomStream(10)
w = shared(np.ones(1, dtype="float32")) w = shared(np.ones(1, dtype="float32"))
...@@ -855,8 +860,12 @@ def test_gradient_scan(): ...@@ -855,8 +860,12 @@ def test_gradient_scan():
x = vector(dtype="float32") x = vector(dtype="float32")
values, updates = scan(one_step, outputs_info=x, n_steps=10) values, updates = scan(one_step, outputs_info=x, n_steps=10)
gw = grad(at_sum(values[-1]), w) gw = grad(at_sum(values[-1]), w)
f = function([x], gw) f = function([x], gw, mode=mode)
f(np.arange(1, dtype="float32")) assert np.allclose(
f(np.arange(1, dtype=np.float32)),
np.array([0.13928187], dtype=np.float32),
rtol=1e-6,
)
def test_simple_shared_mrg_random(): def test_simple_shared_mrg_random():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论