提交 2ee85105 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove references to Op and Apply objects in Scan's Cython code

上级 4ca744f0
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -1362,19 +1362,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1362,19 +1362,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
try: try:
if impl == "py": if impl == "py":
raise MissingGXX raise MissingGXX
cython_mintaps = np.asarray(self.mintaps, dtype="int32") cython_mintaps = np.asarray(self.mintaps, dtype="int32")
cython_tap_array_len = np.asarray(
[len(x) for x in self.tap_array], dtype="int32" tap_array_len = tuple(len(x) for x in self.tap_array)
)
if len(self.tap_array) == 0:
d1 = 0
else:
d1 = np.max(cython_tap_array_len)
d0 = len(self.tap_array)
cython_tap_array = np.zeros((d0, d1), dtype="int32")
for _d0 in range(d0):
for _d1 in range(cython_tap_array_len[_d0]):
cython_tap_array[_d0, _d1] = self.tap_array[_d0][_d1]
cython_mit_mot_out_nslices = np.asarray( cython_mit_mot_out_nslices = np.asarray(
[len(x) for x in self.mit_mot_out_slices], dtype="int32" [len(x) for x in self.mit_mot_out_slices], dtype="int32"
) )
...@@ -1411,10 +1403,19 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1411,10 +1403,19 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_input_storage = [s.storage for s in self.fn.input_storage] inner_input_storage = [s.storage for s in self.fn.input_storage]
inner_output_storage = [s.storage for s in self.fn.output_storage] inner_output_storage = [s.storage for s in self.fn.output_storage]
inner_input_needs_update = [
inp.update is not None for inp in self.fn.maker.expanded_inputs
]
output_dtypes = [getattr(out, "dtype", None) for out in node.outputs]
from . import scan_perform_ext from . import scan_perform_ext
def p(node, inputs, outputs): def p(node, inputs, outputs):
return scan_perform_ext.perform(
t0_call = time.perf_counter()
t_fn = scan_perform_ext.perform(
self.n_shared_outs, self.n_shared_outs,
self.n_mit_mot_outs, self.n_mit_mot_outs,
self.n_seqs, self.n_seqs,
...@@ -1424,8 +1425,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1424,8 +1425,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.n_nit_sot, self.n_nit_sot,
self.as_while, self.as_while,
cython_mintaps, cython_mintaps,
cython_tap_array, self.tap_array,
cython_tap_array_len, tap_array_len,
cython_vector_seqs, cython_vector_seqs,
cython_vector_outs, cython_vector_outs,
cython_mit_mot_out_slices, cython_mit_mot_out_slices,
...@@ -1434,14 +1435,27 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1434,14 +1435,27 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_outs_is_tensor, cython_outs_is_tensor,
inner_input_storage, inner_input_storage,
inner_output_storage, inner_output_storage,
getattr(self.fn.fn, "need_update_inputs", True),
inner_input_needs_update,
self.fn, self.fn,
cython_destroy_map, cython_destroy_map,
inputs, inputs,
outputs, outputs,
self, output_dtypes,
node,
) )
t_call = time.perf_counter() - t0_call
if hasattr(self.fn.maker, "profile"):
profile = self.fn.maker.profile
if type(profile) is not bool and profile:
profile.vm_call_time += t_fn
profile.callcount += 1
profile.nbsteps += outputs[0]
profile.call_time += t_call
if hasattr(self.fn.fn, "update_profile"):
self.fn.fn.update_profile(profile)
except (ImportError, MissingGXX): except (ImportError, MissingGXX):
p = self.perform p = self.perform
......
...@@ -62,31 +62,33 @@ def get_version(): ...@@ -62,31 +62,33 @@ def get_version():
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
unsigned int n_shared_outs, unsigned int n_shared_outs,
unsigned int n_mit_mot_outs, unsigned int n_mit_mot_outs,
unsigned int n_seqs, unsigned int n_seqs,
unsigned int n_mit_mot, unsigned int n_mit_mot,
unsigned int n_mit_sot, unsigned int n_mit_sot,
unsigned int n_sit_sot, unsigned int n_sit_sot,
unsigned int n_nit_sot, unsigned int n_nit_sot,
bint as_while, bint as_while,
numpy.ndarray[numpy.int32_t,ndim=1] mintaps, numpy.ndarray[numpy.int32_t,ndim=1] mintaps,
numpy.ndarray[numpy.int32_t,ndim=2] tap_array, tuple tap_array,
numpy.ndarray[numpy.int32_t,ndim=1] tap_array_len, tuple tap_array_len,
numpy.ndarray[numpy.int32_t,ndim=1] vector_seqs, numpy.ndarray[numpy.int32_t,ndim=1] vector_seqs,
numpy.ndarray[numpy.int32_t,ndim=1] vector_outs, numpy.ndarray[numpy.int32_t,ndim=1] vector_outs,
numpy.ndarray[numpy.int32_t,ndim=2] mit_mot_out_slices, numpy.ndarray[numpy.int32_t,ndim=2] mit_mot_out_slices,
numpy.ndarray[numpy.int32_t,ndim=1] mitmots_preallocated, numpy.ndarray[numpy.int32_t,ndim=1] mitmots_preallocated,
numpy.ndarray[numpy.int32_t,ndim=1] inps_is_tensor, numpy.ndarray[numpy.int32_t,ndim=1] inps_is_tensor,
numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor, numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor,
list inner_input_storage, list inner_input_storage,
list inner_output_storage, list inner_output_storage,
fnct, bint need_update_inputs,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map, list inner_input_needs_update,
list outer_inputs, fnct,
list outer_outputs, numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
self, list outer_inputs,
node): list outer_outputs,
list output_dtypes,
):
""" """
Parameters Parameters
---------- ----------
...@@ -110,13 +112,13 @@ def perform( ...@@ -110,13 +112,13 @@ def perform(
away input tap from current position. For example, if the taps where [-2, away input tap from current position. For example, if the taps where [-2,
-5, -9], the mintap would be -9. For sit_sot this is always -1 since -5, -9], the mintap would be -9. For sit_sot this is always -1 since
is the only allowed tap. is the only allowed tap.
tap_array: int32 ndarray( can be replaced by a list of list in python if better) tap_array
For each of the mit_mot, mit_sot, sit_sot (the first dimension) says For each of the mit_mot, mit_sot, sit_sot (the first dimension) says
which are the corresponding input taps. While this is a matrix, not all which are the corresponding input taps. While this is a matrix, not all
values in a row are needed and tap_array_len is there to say up to values in a row are needed and tap_array_len is there to say up to
which entry we are dealing with valid taps ( afterwards there are which entry we are dealing with valid taps ( afterwards there are
just 0s to ensure the fix format) just 0s to ensure the fix format)
tap_array_len: int32 ndarray( can be replaced by a list if better) tap_array_len
For each of the mit_mot, mit_sot, sit_sot says how many input taps For each of the mit_mot, mit_sot, sit_sot says how many input taps
each has. For sit_sot this will always be 1. each has. For sit_sot this will always be 1.
vector_seqs: int32 ndarray (can be replaced by a list of bools if better) vector_seqs: int32 ndarray (can be replaced by a list of bools if better)
...@@ -138,6 +140,10 @@ def perform( ...@@ -138,6 +140,10 @@ def perform(
The storage locations for the inner-function's inputs. The storage locations for the inner-function's inputs.
inner_output_storage inner_output_storage
The storage locations for the inner-function's outputs. The storage locations for the inner-function's outputs.
need_update_inputs
A boolean indicating whether or not inner inputs need to be updated.
inner_input_needs_update
A list of booleans indicating which inner inputs need to be updated.
fnct: Function fnct: Function
The compiled Aesara inner-function object. The compiled Aesara inner-function object.
destroy_map destroy_map
...@@ -149,15 +155,13 @@ def perform( ...@@ -149,15 +155,13 @@ def perform(
This is where we need to copy our outputs ( we don't return the This is where we need to copy our outputs ( we don't return the
results, though we can change the code such that we return, and results, though we can change the code such that we return, and
figure things out on the outside - python) figure things out on the outside - python)
self: python object output_dtypes
The scan op itself. I only use it to attach to it some timing The dtypes for each output.
information .. but I don;t need to.
""" """
# 1. Unzip the number of steps and sequences. If number of steps is # 1. Unzip the number of steps and sequences. If number of steps is
# negative flip sequences around, and make n_steps positive # negative flip sequences around, and make n_steps positive
t0_call = time.time() cdef unsigned int t_fn = 0
t_fn = 0
cdef unsigned int n_steps = outer_inputs[0].item() cdef unsigned int n_steps = outer_inputs[0].item()
cdef unsigned int n_outs = n_mit_mot + n_mit_sot + n_sit_sot cdef unsigned int n_outs = n_mit_mot + n_mit_sot + n_sit_sot
cdef unsigned int seqs_arg_offset = n_seqs + 1 cdef unsigned int seqs_arg_offset = n_seqs + 1
...@@ -191,7 +195,6 @@ def perform( ...@@ -191,7 +195,6 @@ def perform(
n_sit_sot + n_nit_sot + n_sit_sot + n_nit_sot +
n_shared_outs) n_shared_outs)
if n_steps < 0: if n_steps < 0:
# History, in the past, this was used for backward # History, in the past, this was used for backward
# scan. Now we reverse the inputs outside of scan. # scan. Now we reverse the inputs outside of scan.
...@@ -249,7 +252,7 @@ def perform( ...@@ -249,7 +252,7 @@ def perform(
# (The answer is that you shouldn't have a `node` object to # (The answer is that you shouldn't have a `node` object to
# access, because it's not going to produce a very efficient # access, because it's not going to produce a very efficient
# Cython function!) # Cython function!)
outer_outputs[idx][0] = node.outputs[idx].type.value_zeros(0) outer_outputs[idx][0] = numpy.zeros(0, dtype=output_dtypes[idx])
else: else:
outer_outputs[idx][0] = None outer_outputs[idx][0] = None
return return
...@@ -297,14 +300,14 @@ def perform( ...@@ -297,14 +300,14 @@ def perform(
for idx in range(n_outs): for idx in range(n_outs):
if vector_outs[idx] == 1: if vector_outs[idx] == 1:
for tdx in range(tap_array_len[idx]): for tdx in range(tap_array_len[idx]):
tap = tap_array[idx,tdx] tap = tap_array[idx][tdx]
_idx = (pos[idx]+tap)%store_steps[idx] _idx = (pos[idx]+tap)%store_steps[idx]
inner_input_storage[offset][0] =\ inner_input_storage[offset][0] =\
outer_outputs[idx][0][_idx:<unsigned int>(_idx+1)].reshape(()) outer_outputs[idx][0][_idx:<unsigned int>(_idx+1)].reshape(())
offset += 1 offset += 1
else: else:
for tdx in range(tap_array_len[idx]): for tdx in range(tap_array_len[idx]):
tap = tap_array[idx,tdx] tap = tap_array[idx][tdx]
_idx = (pos[idx]+tap)%store_steps[idx] _idx = (pos[idx]+tap)%store_steps[idx]
inner_input_storage[offset][0] = outer_outputs[idx][0][_idx] inner_input_storage[offset][0] = outer_outputs[idx][0][_idx]
offset += 1 offset += 1
...@@ -416,20 +419,19 @@ def perform( ...@@ -416,20 +419,19 @@ def perform(
dt_fn = time.time() - t0_fn dt_fn = time.time() - t0_fn
t_fn += dt_fn t_fn += dt_fn
if self.as_while: if as_while:
pdx = offset + n_shared_outs pdx = offset + n_shared_outs
cond = inner_output_storage[pdx][0] == 0 cond = inner_output_storage[pdx][0] == 0
# 5.2. By calling fn() directly instead of calling the aesara # 5.2. By calling fn() directly instead of calling the aesara
# function, it is possible that the updates have not been # function, it is possible that the updates have not been
# performed. Perform the updates if needed. # performed. Perform the updates if needed.
offset_out = len(inner_output_storage) - 1 if need_update_inputs:
if getattr(fn, 'need_update_inputs', True): offset_out = len(inner_output_storage) - 1
# Update the inputs that have an update function for needs_update, storage in zip(inner_input_needs_update[::-1],
for inp, storage in zip(self.fn.maker.expanded_inputs[::-1], inner_input_storage[::-1]):
self.fn.input_storage[::-1]): if needs_update:
if inp.update is not None: storage[0] = inner_output_storage[offset_out][0]
storage.data = inner_output_storage[offset_out][0].data
offset_out -= 1 offset_out -= 1
offset_out = 0 offset_out = 0
...@@ -437,12 +439,11 @@ def perform( ...@@ -437,12 +439,11 @@ def perform(
# 5.3 Copy over the values for mit_mot outputs # 5.3 Copy over the values for mit_mot outputs
mitmot_inp_offset = 0 mitmot_inp_offset = 0
mitmot_out_idx = 0 mitmot_out_idx = 0
for j in xrange(self.n_mit_mot): for j in xrange(n_mit_mot):
for k in self.mit_mot_out_slices[j]: for k in mit_mot_out_slices[j]:
if mitmots_preallocated[<unsigned int>mitmot_out_idx]: if mitmots_preallocated[<unsigned int>mitmot_out_idx]:
# This output tap has been preallocated. # This output tap has been preallocated.
inp_idx = (mitmot_inp_offset + inp_idx = (mitmot_inp_offset + tap_array[j].index(k))
self.tap_array[j].index(k))
# Verify whether the input points to the same data as # Verify whether the input points to the same data as
# it did before the execution of the inner function. # it did before the execution of the inner function.
...@@ -473,7 +474,7 @@ def perform( ...@@ -473,7 +474,7 @@ def perform(
mitmot_out_idx += 1 mitmot_out_idx += 1
mitmot_inp_offset += len(self.tap_array[j]) mitmot_inp_offset += len(tap_array[j])
# 5.4 Copy over the values for mit_sot/sit_sot outputs # 5.4 Copy over the values for mit_sot/sit_sot outputs
begin = n_mit_mot begin = n_mit_mot
...@@ -519,7 +520,7 @@ def perform( ...@@ -519,7 +520,7 @@ def perform(
outer_outputs[j][0].shape[0] < store_steps[j] or outer_outputs[j][0].shape[0] < store_steps[j] or
outer_outputs[j][0].shape[1:] != shape[1:] or outer_outputs[j][0].shape[1:] != shape[1:] or
outer_outputs[j][0].dtype != dtype ): outer_outputs[j][0].dtype != dtype ):
outer_outputs[j][0] = node.outputs[j].type.value_zeros(shape) outer_outputs[j][0] = numpy.zeros(shape, dtype=output_dtypes[j])
elif outer_outputs[j][0].shape[0] != store_steps[j]: elif outer_outputs[j][0].shape[0] != store_steps[j]:
outer_outputs[j][0] = outer_outputs[j][0][:store_steps[j]] outer_outputs[j][0] = outer_outputs[j][0][:store_steps[j]]
outer_outputs[j][0][pos[j]] = inner_output_storage[jout][0] outer_outputs[j][0][pos[j]] = inner_output_storage[jout][0]
...@@ -581,23 +582,23 @@ def perform( ...@@ -581,23 +582,23 @@ def perform(
# This way, there will be no information overwritten # This way, there will be no information overwritten
# before it is read (as it used to happen). # before it is read (as it used to happen).
shape = (pdx,)+ outer_outputs[idx][0].shape[1:] shape = (pdx,)+ outer_outputs[idx][0].shape[1:]
tmp = numpy.zeros(shape, dtype=output_dtypes[idx])
tmp = node.outputs[idx].type.value_zeros(shape)
tmp[:] = outer_outputs[idx][0][:pdx] tmp[:] = outer_outputs[idx][0][:pdx]
outer_outputs[idx][0][:store_steps[idx]-pdx] = outer_outputs[idx][0][pdx:] outer_outputs[idx][0][:store_steps[idx]-pdx] = outer_outputs[idx][0][pdx:]
outer_outputs[idx][0][store_steps[idx]-pdx:] = tmp outer_outputs[idx][0][store_steps[idx]-pdx:] = tmp
else: else:
shape = (store_steps[idx]-pdx,) + outer_outputs[idx][0].shape[1:] shape = (store_steps[idx]-pdx,) + outer_outputs[idx][0].shape[1:]
tmp = node.outputs[idx].type.value_zeros(shape) tmp = numpy.zeros(shape, dtype=output_dtypes[idx])
tmp[:] = outer_outputs[idx][0][pdx:] tmp[:] = outer_outputs[idx][0][pdx:]
outer_outputs[idx][0][store_steps[idx]-pdx:] = outer_outputs[idx][0][:pdx] outer_outputs[idx][0][store_steps[idx]-pdx:] = outer_outputs[idx][0][:pdx]
outer_outputs[idx][0][:store_steps[idx]-pdx] = tmp outer_outputs[idx][0][:store_steps[idx]-pdx] = tmp
# This would normally happen only when doing truncated # This would normally happen only when doing truncated
# backpropagation through time. In such a scenario Scan is # backpropagation through time. In such a scenario Scan is
# expected to return 0 for all entries for which the gradient is # expected to return 0 for all entries for which the gradient is
# not actually computed # not actually computed
elif store_steps[idx] > i - self.mintaps[idx]: elif store_steps[idx] > i - mintaps[idx]:
outer_outputs[idx][0][i-self.mintaps[idx]:] = 0 outer_outputs[idx][0][i - mintaps[idx]:] = 0
# This is a fix for a bug introduced by while. If you say # This is a fix for a bug introduced by while. If you say
# you want to loop up to a condition, you expect the output # you want to loop up to a condition, you expect the output
...@@ -623,17 +624,4 @@ def perform( ...@@ -623,17 +624,4 @@ def perform(
for s in inner_output_storage: for s in inner_output_storage:
s[0] = None s[0] = None
t_call = time.time() - t0_call return t_fn
if hasattr(fnct.maker, 'profile'):
profile = fnct.maker.profile
if type(profile) is not bool and profile:
profile.vm_call_time += t_fn
profile.callcount += 1
profile.nbsteps += n_steps
profile.call_time += t_call
if hasattr(fn, 'update_profile'):
fn.update_profile(profile)
self.t_call = t_call
self.t_fn = t_fn
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论