提交 c997333d 作者: Brandon T. Willard 提交者: Brandon T. Willard

Use np.empty instead of np.zeros to allocate memory in Scan

上级 2ee85105
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -45,10 +45,10 @@ relies on the following elements to work properly : ...@@ -45,10 +45,10 @@ relies on the following elements to work properly :
import dataclasses import dataclasses
import itertools
import logging import logging
import time import time
from collections import OrderedDict from collections import OrderedDict
from itertools import product
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
...@@ -530,7 +530,7 @@ class ScanMethodsMixin: ...@@ -530,7 +530,7 @@ class ScanMethodsMixin:
inner_iidxs = var_mappings["inner_inp_from_outer_out"][outer_oidx] inner_iidxs = var_mappings["inner_inp_from_outer_out"][outer_oidx]
inner_oidxs = var_mappings["inner_out_from_outer_out"][outer_oidx] inner_oidxs = var_mappings["inner_out_from_outer_out"][outer_oidx]
for (inner_iidx, inner_oidx) in itertools.product(inner_iidxs, inner_oidxs): for (inner_iidx, inner_oidx) in product(inner_iidxs, inner_oidxs):
type_input = self.inputs[inner_iidx].type type_input = self.inputs[inner_iidx].type
type_output = self.outputs[inner_oidx].type type_output = self.outputs[inner_oidx].type
...@@ -1363,25 +1363,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1363,25 +1363,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if impl == "py": if impl == "py":
raise MissingGXX raise MissingGXX
from . import scan_perform_ext
cython_mintaps = np.asarray(self.mintaps, dtype="int32") cython_mintaps = np.asarray(self.mintaps, dtype="int32")
tap_array_len = tuple(len(x) for x in self.tap_array) tap_array_len = tuple(len(x) for x in self.tap_array)
cython_mit_mot_out_nslices = np.asarray(
[len(x) for x in self.mit_mot_out_slices], dtype="int32"
)
if len(self.mit_mot_out_slices) == 0:
d1 = 0
else:
d1 = np.max(cython_mit_mot_out_nslices)
d0 = len(self.mit_mot_out_slices)
cython_mit_mot_out_slices = np.zeros((d0, d1), dtype="int32")
for _d0 in range(d0):
for _d1 in range(cython_mit_mot_out_nslices[_d0]):
cython_mit_mot_out_slices[_d0, _d1] = self.mit_mot_out_slices[_d0][
_d1
]
cython_vector_seqs = np.asarray(self.vector_seqs, dtype="int32") cython_vector_seqs = np.asarray(self.vector_seqs, dtype="int32")
cython_vector_outs = np.asarray(self.vector_outs, dtype="int32") cython_vector_outs = np.asarray(self.vector_outs, dtype="int32")
cython_mitmots_preallocated = np.asarray( cython_mitmots_preallocated = np.asarray(
...@@ -1403,13 +1390,16 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1403,13 +1390,16 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_input_storage = [s.storage for s in self.fn.input_storage] inner_input_storage = [s.storage for s in self.fn.input_storage]
inner_output_storage = [s.storage for s in self.fn.output_storage] inner_output_storage = [s.storage for s in self.fn.output_storage]
inner_input_needs_update = [ inner_input_needs_update = tuple(
inp.update is not None for inp in self.fn.maker.expanded_inputs inp.update is not None for inp in self.fn.maker.expanded_inputs
] )
output_dtypes = [getattr(out, "dtype", None) for out in node.outputs]
from . import scan_perform_ext outer_output_dtypes = tuple(
getattr(out, "dtype", None) for out in node.outputs
)
outer_output_ndims = tuple(
getattr(out, "ndim", None) for out in node.outputs
)
def p(node, inputs, outputs): def p(node, inputs, outputs):
...@@ -1429,7 +1419,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1429,7 +1419,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
tap_array_len, tap_array_len,
cython_vector_seqs, cython_vector_seqs,
cython_vector_outs, cython_vector_outs,
cython_mit_mot_out_slices, self.mit_mot_out_slices,
cython_mitmots_preallocated, cython_mitmots_preallocated,
cython_inps_is_tensor, cython_inps_is_tensor,
cython_outs_is_tensor, cython_outs_is_tensor,
...@@ -1441,7 +1431,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1441,7 +1431,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_destroy_map, cython_destroy_map,
inputs, inputs,
outputs, outputs,
output_dtypes, outer_output_dtypes,
outer_output_ndims,
) )
t_call = time.perf_counter() - t0_call t_call = time.perf_counter() - t0_call
...@@ -1566,7 +1557,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1566,7 +1557,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for idx in range(self.n_outs, self.n_outs + self.n_nit_sot): for idx in range(self.n_outs, self.n_outs + self.n_nit_sot):
out_var = node.outputs[idx] out_var = node.outputs[idx]
if isinstance(out_var, TensorVariable): if isinstance(out_var, TensorVariable):
output_storage[idx][0] = out_var.type.value_zeros(0) output_storage[idx][0] = np.empty(
(0,) * out_var.type.ndim, dtype=out_var.type.dtype
)
else: else:
output_storage[idx][0] = None output_storage[idx][0] = None
return return
...@@ -1875,7 +1868,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1875,7 +1868,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
or output_storage[j][0].shape[1:] != shape[1:] or output_storage[j][0].shape[1:] != shape[1:]
or output_storage[j][0].dtype != dtype or output_storage[j][0].dtype != dtype
): ):
output_storage[j][0] = node.outputs[j].type.value_zeros(shape) output_storage[j][0] = np.empty(
shape, dtype=node.outputs[j].type
)
elif output_storage[j][0].shape[0] != store_steps[j]: elif output_storage[j][0].shape[0] != store_steps[j]:
output_storage[j][0] = output_storage[j][0][: store_steps[j]] output_storage[j][0] = output_storage[j][0][: store_steps[j]]
output_storage[j][0][pos[j]] = inner_output_storage[jout].storage[0] output_storage[j][0][pos[j]] = inner_output_storage[jout].storage[0]
...@@ -1932,7 +1927,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1932,7 +1927,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# This way, there will be no information overwritten # This way, there will be no information overwritten
# before it is read (as it used to happen). # before it is read (as it used to happen).
shape = (pdx,) + output_storage[idx][0].shape[1:] shape = (pdx,) + output_storage[idx][0].shape[1:]
tmp = node.outputs[idx].type.value_zeros(shape) tmp = np.empty(shape, dtype=node.outputs[idx].type)
tmp[:] = output_storage[idx][0][:pdx] tmp[:] = output_storage[idx][0][:pdx]
output_storage[idx][0][: store_steps[idx] - pdx] = output_storage[ output_storage[idx][0][: store_steps[idx] - pdx] = output_storage[
idx idx
...@@ -1941,7 +1936,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1941,7 +1936,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
del tmp del tmp
else: else:
shape = (store_steps[idx] - pdx,) + output_storage[idx][0].shape[1:] shape = (store_steps[idx] - pdx,) + output_storage[idx][0].shape[1:]
tmp = node.outputs[idx].type.value_zeros(shape) tmp = np.empty(shape, dtype=node.outputs[idx].type)
tmp[:] = output_storage[idx][0][pdx:] tmp[:] = output_storage[idx][0][pdx:]
output_storage[idx][0][store_steps[idx] - pdx :] = output_storage[ output_storage[idx][0][store_steps[idx] - pdx :] = output_storage[
idx idx
......
...@@ -75,19 +75,20 @@ def perform( ...@@ -75,19 +75,20 @@ def perform(
tuple tap_array_len, tuple tap_array_len,
numpy.ndarray[numpy.int32_t,ndim=1] vector_seqs, numpy.ndarray[numpy.int32_t,ndim=1] vector_seqs,
numpy.ndarray[numpy.int32_t,ndim=1] vector_outs, numpy.ndarray[numpy.int32_t,ndim=1] vector_outs,
numpy.ndarray[numpy.int32_t,ndim=2] mit_mot_out_slices, tuple mit_mot_out_slices,
numpy.ndarray[numpy.int32_t,ndim=1] mitmots_preallocated, numpy.ndarray[numpy.int32_t,ndim=1] mitmots_preallocated,
numpy.ndarray[numpy.int32_t,ndim=1] inps_is_tensor, numpy.ndarray[numpy.int32_t,ndim=1] inps_is_tensor,
numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor, numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor,
list inner_input_storage, list inner_input_storage,
list inner_output_storage, list inner_output_storage,
bint need_update_inputs, bint need_update_inputs,
list inner_input_needs_update, tuple inner_input_needs_update,
fnct, fnct,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map, numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
list outer_inputs, list outer_inputs,
list outer_outputs, list outer_outputs,
list output_dtypes, tuple outer_output_dtypes,
tuple outer_output_ndims,
): ):
""" """
Parameters Parameters
...@@ -128,7 +129,7 @@ def perform( ...@@ -128,7 +129,7 @@ def perform(
For each output ( mit_mot, mit_sot, sit_sot, nit_sot in this order) For each output ( mit_mot, mit_sot, sit_sot, nit_sot in this order)
the entry is 1 if the corresponding argument is a 1 dimensional the entry is 1 if the corresponding argument is a 1 dimensional
tensor, 0 otherwise. tensor, 0 otherwise.
mit_mot_out_slices : int32 ndarray( can be replaced by list of lists) mit_mot_out_slices
Same as tap_array, but for the output taps of mit_mot sequences Same as tap_array, but for the output taps of mit_mot sequences
inps_is_tensor : int32 ndarray (Can be replaced by a list) inps_is_tensor : int32 ndarray (Can be replaced by a list)
Array of boolean indicating, for every input, whether it is a tensor Array of boolean indicating, for every input, whether it is a tensor
...@@ -143,7 +144,7 @@ def perform( ...@@ -143,7 +144,7 @@ def perform(
need_update_inputs need_update_inputs
A boolean indicating whether or not inner inputs need to be updated. A boolean indicating whether or not inner inputs need to be updated.
inner_input_needs_update inner_input_needs_update
A list of booleans indicating which inner inputs need to be updated. A tuple of booleans indicating which inner inputs need to be updated.
fnct: Function fnct: Function
The compiled Aesara inner-function object. The compiled Aesara inner-function object.
destroy_map destroy_map
...@@ -155,8 +156,10 @@ def perform( ...@@ -155,8 +156,10 @@ def perform(
This is where we need to copy our outputs ( we don't return the This is where we need to copy our outputs ( we don't return the
results, though we can change the code such that we return, and results, though we can change the code such that we return, and
figure things out on the outside - python) figure things out on the outside - python)
output_dtypes outer_output_dtypes
The dtypes for each output. The dtypes for each outer output.
outer_output_ndims
The number of dimensions for each outer output.
""" """
# 1. Unzip the number of steps and sequences. If number of steps is # 1. Unzip the number of steps and sequences. If number of steps is
...@@ -252,7 +255,7 @@ def perform( ...@@ -252,7 +255,7 @@ def perform(
# (The answer is that you shouldn't have a `node` object to # (The answer is that you shouldn't have a `node` object to
# access, because it's not going to produce a very efficient # access, because it's not going to produce a very efficient
# Cython function!) # Cython function!)
outer_outputs[idx][0] = numpy.zeros(0, dtype=output_dtypes[idx]) outer_outputs[idx][0] = numpy.empty((0,) * outer_output_ndims[idx], dtype=outer_output_dtypes[idx])
else: else:
outer_outputs[idx][0] = None outer_outputs[idx][0] = None
return return
...@@ -299,15 +302,13 @@ def perform( ...@@ -299,15 +302,13 @@ def perform(
offset = n_seqs offset = n_seqs
for idx in range(n_outs): for idx in range(n_outs):
if vector_outs[idx] == 1: if vector_outs[idx] == 1:
for tdx in range(tap_array_len[idx]): for tap in tap_array[idx]:
tap = tap_array[idx][tdx]
_idx = (pos[idx]+tap)%store_steps[idx] _idx = (pos[idx]+tap)%store_steps[idx]
inner_input_storage[offset][0] =\ inner_input_storage[offset][0] =\
outer_outputs[idx][0][_idx:<unsigned int>(_idx+1)].reshape(()) outer_outputs[idx][0][_idx:<unsigned int>(_idx+1)].reshape(())
offset += 1 offset += 1
else: else:
for tdx in range(tap_array_len[idx]): for tap in tap_array[idx]:
tap = tap_array[idx][tdx]
_idx = (pos[idx]+tap)%store_steps[idx] _idx = (pos[idx]+tap)%store_steps[idx]
inner_input_storage[offset][0] = outer_outputs[idx][0][_idx] inner_input_storage[offset][0] = outer_outputs[idx][0][_idx]
offset += 1 offset += 1
...@@ -474,7 +475,7 @@ def perform( ...@@ -474,7 +475,7 @@ def perform(
mitmot_out_idx += 1 mitmot_out_idx += 1
mitmot_inp_offset += len(tap_array[j]) mitmot_inp_offset += tap_array_len[j]
# 5.4 Copy over the values for mit_sot/sit_sot outputs # 5.4 Copy over the values for mit_sot/sit_sot outputs
begin = n_mit_mot begin = n_mit_mot
...@@ -520,7 +521,7 @@ def perform( ...@@ -520,7 +521,7 @@ def perform(
outer_outputs[j][0].shape[0] < store_steps[j] or outer_outputs[j][0].shape[0] < store_steps[j] or
outer_outputs[j][0].shape[1:] != shape[1:] or outer_outputs[j][0].shape[1:] != shape[1:] or
outer_outputs[j][0].dtype != dtype ): outer_outputs[j][0].dtype != dtype ):
outer_outputs[j][0] = numpy.zeros(shape, dtype=output_dtypes[j]) outer_outputs[j][0] = numpy.empty(shape, dtype=outer_output_dtypes[j])
elif outer_outputs[j][0].shape[0] != store_steps[j]: elif outer_outputs[j][0].shape[0] != store_steps[j]:
outer_outputs[j][0] = outer_outputs[j][0][:store_steps[j]] outer_outputs[j][0] = outer_outputs[j][0][:store_steps[j]]
outer_outputs[j][0][pos[j]] = inner_output_storage[jout][0] outer_outputs[j][0][pos[j]] = inner_output_storage[jout][0]
...@@ -582,13 +583,13 @@ def perform( ...@@ -582,13 +583,13 @@ def perform(
# This way, there will be no information overwritten # This way, there will be no information overwritten
# before it is read (as it used to happen). # before it is read (as it used to happen).
shape = (pdx,)+ outer_outputs[idx][0].shape[1:] shape = (pdx,)+ outer_outputs[idx][0].shape[1:]
tmp = numpy.zeros(shape, dtype=output_dtypes[idx]) tmp = numpy.empty(shape, dtype=outer_output_dtypes[idx])
tmp[:] = outer_outputs[idx][0][:pdx] tmp[:] = outer_outputs[idx][0][:pdx]
outer_outputs[idx][0][:store_steps[idx]-pdx] = outer_outputs[idx][0][pdx:] outer_outputs[idx][0][:store_steps[idx]-pdx] = outer_outputs[idx][0][pdx:]
outer_outputs[idx][0][store_steps[idx]-pdx:] = tmp outer_outputs[idx][0][store_steps[idx]-pdx:] = tmp
else: else:
shape = (store_steps[idx]-pdx,) + outer_outputs[idx][0].shape[1:] shape = (store_steps[idx]-pdx,) + outer_outputs[idx][0].shape[1:]
tmp = numpy.zeros(shape, dtype=output_dtypes[idx]) tmp = numpy.empty(shape, dtype=outer_output_dtypes[idx])
tmp[:] = outer_outputs[idx][0][pdx:] tmp[:] = outer_outputs[idx][0][pdx:]
outer_outputs[idx][0][store_steps[idx]-pdx:] = outer_outputs[idx][0][:pdx] outer_outputs[idx][0][store_steps[idx]-pdx:] = outer_outputs[idx][0][:pdx]
outer_outputs[idx][0][:store_steps[idx]-pdx] = tmp outer_outputs[idx][0][:store_steps[idx]-pdx] = tmp
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论