提交 2ee85105 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove references to Op and Apply objects in Scan's Cython code

上级 4ca744f0
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -1362,19 +1362,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
try:
if impl == "py":
raise MissingGXX
cython_mintaps = np.asarray(self.mintaps, dtype="int32")
cython_tap_array_len = np.asarray(
[len(x) for x in self.tap_array], dtype="int32"
)
if len(self.tap_array) == 0:
d1 = 0
else:
d1 = np.max(cython_tap_array_len)
d0 = len(self.tap_array)
cython_tap_array = np.zeros((d0, d1), dtype="int32")
for _d0 in range(d0):
for _d1 in range(cython_tap_array_len[_d0]):
cython_tap_array[_d0, _d1] = self.tap_array[_d0][_d1]
tap_array_len = tuple(len(x) for x in self.tap_array)
cython_mit_mot_out_nslices = np.asarray(
[len(x) for x in self.mit_mot_out_slices], dtype="int32"
)
......@@ -1411,10 +1403,19 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_input_storage = [s.storage for s in self.fn.input_storage]
inner_output_storage = [s.storage for s in self.fn.output_storage]
inner_input_needs_update = [
inp.update is not None for inp in self.fn.maker.expanded_inputs
]
output_dtypes = [getattr(out, "dtype", None) for out in node.outputs]
from . import scan_perform_ext
def p(node, inputs, outputs):
return scan_perform_ext.perform(
t0_call = time.perf_counter()
t_fn = scan_perform_ext.perform(
self.n_shared_outs,
self.n_mit_mot_outs,
self.n_seqs,
......@@ -1424,8 +1425,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.n_nit_sot,
self.as_while,
cython_mintaps,
cython_tap_array,
cython_tap_array_len,
self.tap_array,
tap_array_len,
cython_vector_seqs,
cython_vector_outs,
cython_mit_mot_out_slices,
......@@ -1434,14 +1435,27 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_outs_is_tensor,
inner_input_storage,
inner_output_storage,
getattr(self.fn.fn, "need_update_inputs", True),
inner_input_needs_update,
self.fn,
cython_destroy_map,
inputs,
outputs,
self,
node,
output_dtypes,
)
t_call = time.perf_counter() - t0_call
if hasattr(self.fn.maker, "profile"):
profile = self.fn.maker.profile
if type(profile) is not bool and profile:
profile.vm_call_time += t_fn
profile.callcount += 1
profile.nbsteps += outputs[0]
profile.call_time += t_call
if hasattr(self.fn.fn, "update_profile"):
self.fn.fn.update_profile(profile)
except (ImportError, MissingGXX):
p = self.perform
......
......@@ -62,31 +62,33 @@ def get_version():
@cython.boundscheck(False)
def perform(
unsigned int n_shared_outs,
unsigned int n_mit_mot_outs,
unsigned int n_seqs,
unsigned int n_mit_mot,
unsigned int n_mit_sot,
unsigned int n_sit_sot,
unsigned int n_nit_sot,
bint as_while,
numpy.ndarray[numpy.int32_t,ndim=1] mintaps,
numpy.ndarray[numpy.int32_t,ndim=2] tap_array,
numpy.ndarray[numpy.int32_t,ndim=1] tap_array_len,
numpy.ndarray[numpy.int32_t,ndim=1] vector_seqs,
numpy.ndarray[numpy.int32_t,ndim=1] vector_outs,
numpy.ndarray[numpy.int32_t,ndim=2] mit_mot_out_slices,
numpy.ndarray[numpy.int32_t,ndim=1] mitmots_preallocated,
numpy.ndarray[numpy.int32_t,ndim=1] inps_is_tensor,
numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor,
list inner_input_storage,
list inner_output_storage,
fnct,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
list outer_inputs,
list outer_outputs,
self,
node):
unsigned int n_shared_outs,
unsigned int n_mit_mot_outs,
unsigned int n_seqs,
unsigned int n_mit_mot,
unsigned int n_mit_sot,
unsigned int n_sit_sot,
unsigned int n_nit_sot,
bint as_while,
numpy.ndarray[numpy.int32_t,ndim=1] mintaps,
tuple tap_array,
tuple tap_array_len,
numpy.ndarray[numpy.int32_t,ndim=1] vector_seqs,
numpy.ndarray[numpy.int32_t,ndim=1] vector_outs,
numpy.ndarray[numpy.int32_t,ndim=2] mit_mot_out_slices,
numpy.ndarray[numpy.int32_t,ndim=1] mitmots_preallocated,
numpy.ndarray[numpy.int32_t,ndim=1] inps_is_tensor,
numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor,
list inner_input_storage,
list inner_output_storage,
bint need_update_inputs,
list inner_input_needs_update,
fnct,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
list outer_inputs,
list outer_outputs,
list output_dtypes,
):
"""
Parameters
----------
......@@ -110,13 +112,13 @@ def perform(
away input tap from current position. For example, if the taps where [-2,
-5, -9], the mintap would be -9. For sit_sot this is always -1 since
is the only allowed tap.
tap_array: int32 ndarray( can be replaced by a list of list in python if better)
tap_array
For each of the mit_mot, mit_sot, sit_sot (the first dimension) says
which are the corresponding input taps. While this is a matrix, not all
values in a row are needed and tap_array_len is there to say up to
which entry we are dealing with valid taps ( afterwards there are
just 0s to ensure the fix format)
tap_array_len: int32 ndarray( can be replaced by a list if better)
tap_array_len
For each of the mit_mot, mit_sot, sit_sot says how many input taps
each has. For sit_sot this will always be 1.
vector_seqs: int32 ndarray (can be replaced by a list of bools if better)
......@@ -138,6 +140,10 @@ def perform(
The storage locations for the inner-function's inputs.
inner_output_storage
The storage locations for the inner-function's outputs.
need_update_inputs
A boolean indicating whether or not inner inputs need to be updated.
inner_input_needs_update
A list of booleans indicating which inner inputs need to be updated.
fnct: Function
The compiled Aesara inner-function object.
destroy_map
......@@ -149,15 +155,13 @@ def perform(
This is where we need to copy our outputs ( we don't return the
results, though we can change the code such that we return, and
figure things out on the outside - python)
self: python object
The scan op itself. I only use it to attach to it some timing
information .. but I don;t need to.
output_dtypes
The dtypes for each output.
"""
# 1. Unzip the number of steps and sequences. If number of steps is
# negative flip sequences around, and make n_steps positive
t0_call = time.time()
t_fn = 0
cdef unsigned int t_fn = 0
cdef unsigned int n_steps = outer_inputs[0].item()
cdef unsigned int n_outs = n_mit_mot + n_mit_sot + n_sit_sot
cdef unsigned int seqs_arg_offset = n_seqs + 1
......@@ -191,7 +195,6 @@ def perform(
n_sit_sot + n_nit_sot +
n_shared_outs)
if n_steps < 0:
# History, in the past, this was used for backward
# scan. Now we reverse the inputs outside of scan.
......@@ -249,7 +252,7 @@ def perform(
# (The answer is that you shouldn't have a `node` object to
# access, because it's not going to produce a very efficient
# Cython function!)
outer_outputs[idx][0] = node.outputs[idx].type.value_zeros(0)
outer_outputs[idx][0] = numpy.zeros(0, dtype=output_dtypes[idx])
else:
outer_outputs[idx][0] = None
return
......@@ -297,14 +300,14 @@ def perform(
for idx in range(n_outs):
if vector_outs[idx] == 1:
for tdx in range(tap_array_len[idx]):
tap = tap_array[idx,tdx]
tap = tap_array[idx][tdx]
_idx = (pos[idx]+tap)%store_steps[idx]
inner_input_storage[offset][0] =\
outer_outputs[idx][0][_idx:<unsigned int>(_idx+1)].reshape(())
offset += 1
else:
for tdx in range(tap_array_len[idx]):
tap = tap_array[idx,tdx]
tap = tap_array[idx][tdx]
_idx = (pos[idx]+tap)%store_steps[idx]
inner_input_storage[offset][0] = outer_outputs[idx][0][_idx]
offset += 1
......@@ -416,20 +419,19 @@ def perform(
dt_fn = time.time() - t0_fn
t_fn += dt_fn
if self.as_while:
if as_while:
pdx = offset + n_shared_outs
cond = inner_output_storage[pdx][0] == 0
# 5.2. By calling fn() directly instead of calling the aesara
# function, it is possible that the updates have not been
# performed. Perform the updates if needed.
offset_out = len(inner_output_storage) - 1
if getattr(fn, 'need_update_inputs', True):
# Update the inputs that have an update function
for inp, storage in zip(self.fn.maker.expanded_inputs[::-1],
self.fn.input_storage[::-1]):
if inp.update is not None:
storage.data = inner_output_storage[offset_out][0].data
if need_update_inputs:
offset_out = len(inner_output_storage) - 1
for needs_update, storage in zip(inner_input_needs_update[::-1],
inner_input_storage[::-1]):
if needs_update:
storage[0] = inner_output_storage[offset_out][0]
offset_out -= 1
offset_out = 0
......@@ -437,12 +439,11 @@ def perform(
# 5.3 Copy over the values for mit_mot outputs
mitmot_inp_offset = 0
mitmot_out_idx = 0
for j in xrange(self.n_mit_mot):
for k in self.mit_mot_out_slices[j]:
for j in xrange(n_mit_mot):
for k in mit_mot_out_slices[j]:
if mitmots_preallocated[<unsigned int>mitmot_out_idx]:
# This output tap has been preallocated.
inp_idx = (mitmot_inp_offset +
self.tap_array[j].index(k))
inp_idx = (mitmot_inp_offset + tap_array[j].index(k))
# Verify whether the input points to the same data as
# it did before the execution of the inner function.
......@@ -473,7 +474,7 @@ def perform(
mitmot_out_idx += 1
mitmot_inp_offset += len(self.tap_array[j])
mitmot_inp_offset += len(tap_array[j])
# 5.4 Copy over the values for mit_sot/sit_sot outputs
begin = n_mit_mot
......@@ -519,7 +520,7 @@ def perform(
outer_outputs[j][0].shape[0] < store_steps[j] or
outer_outputs[j][0].shape[1:] != shape[1:] or
outer_outputs[j][0].dtype != dtype ):
outer_outputs[j][0] = node.outputs[j].type.value_zeros(shape)
outer_outputs[j][0] = numpy.zeros(shape, dtype=output_dtypes[j])
elif outer_outputs[j][0].shape[0] != store_steps[j]:
outer_outputs[j][0] = outer_outputs[j][0][:store_steps[j]]
outer_outputs[j][0][pos[j]] = inner_output_storage[jout][0]
......@@ -581,23 +582,23 @@ def perform(
# This way, there will be no information overwritten
# before it is read (as it used to happen).
shape = (pdx,)+ outer_outputs[idx][0].shape[1:]
tmp = node.outputs[idx].type.value_zeros(shape)
tmp = numpy.zeros(shape, dtype=output_dtypes[idx])
tmp[:] = outer_outputs[idx][0][:pdx]
outer_outputs[idx][0][:store_steps[idx]-pdx] = outer_outputs[idx][0][pdx:]
outer_outputs[idx][0][store_steps[idx]-pdx:] = tmp
else:
shape = (store_steps[idx]-pdx,) + outer_outputs[idx][0].shape[1:]
tmp = node.outputs[idx].type.value_zeros(shape)
tmp = numpy.zeros(shape, dtype=output_dtypes[idx])
tmp[:] = outer_outputs[idx][0][pdx:]
outer_outputs[idx][0][store_steps[idx]-pdx:] = outer_outputs[idx][0][:pdx]
outer_outputs[idx][0][:store_steps[idx]-pdx] = tmp
# This would normally happen only when doing truncated
# backpropagation through time. In such a scenario Scan is
# expected to return 0 for all entries for which the gradient is
# not actually computed
elif store_steps[idx] > i - self.mintaps[idx]:
outer_outputs[idx][0][i-self.mintaps[idx]:] = 0
elif store_steps[idx] > i - mintaps[idx]:
outer_outputs[idx][0][i - mintaps[idx]:] = 0
# This is a fix for a bug introduced by while. If you say
# you want to loop up to a condition, you expect the output
......@@ -623,17 +624,4 @@ def perform(
for s in inner_output_storage:
s[0] = None
t_call = time.time() - t0_call
if hasattr(fnct.maker, 'profile'):
profile = fnct.maker.profile
if type(profile) is not bool and profile:
profile.vm_call_time += t_fn
profile.callcount += 1
profile.nbsteps += n_steps
profile.call_time += t_call
if hasattr(fn, 'update_profile'):
fn.update_profile(profile)
self.t_call = t_call
self.t_fn = t_fn
return t_fn
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论