Commit c997333d authored by Brandon T. Willard, committed by Brandon T. Willard

Use np.empty instead of np.zeros to allocate memory in Scan

Parent 2ee85105
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -45,10 +45,10 @@ relies on the following elements to work properly :
import dataclasses
import itertools
import logging
import time
from collections import OrderedDict
from itertools import product
from typing import Callable, List, Optional, Union
import numpy as np
......@@ -530,7 +530,7 @@ class ScanMethodsMixin:
inner_iidxs = var_mappings["inner_inp_from_outer_out"][outer_oidx]
inner_oidxs = var_mappings["inner_out_from_outer_out"][outer_oidx]
for (inner_iidx, inner_oidx) in itertools.product(inner_iidxs, inner_oidxs):
for (inner_iidx, inner_oidx) in product(inner_iidxs, inner_oidxs):
type_input = self.inputs[inner_iidx].type
type_output = self.outputs[inner_oidx].type
......@@ -1363,25 +1363,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if impl == "py":
raise MissingGXX
from . import scan_perform_ext
cython_mintaps = np.asarray(self.mintaps, dtype="int32")
tap_array_len = tuple(len(x) for x in self.tap_array)
cython_mit_mot_out_nslices = np.asarray(
[len(x) for x in self.mit_mot_out_slices], dtype="int32"
)
if len(self.mit_mot_out_slices) == 0:
d1 = 0
else:
d1 = np.max(cython_mit_mot_out_nslices)
d0 = len(self.mit_mot_out_slices)
cython_mit_mot_out_slices = np.zeros((d0, d1), dtype="int32")
for _d0 in range(d0):
for _d1 in range(cython_mit_mot_out_nslices[_d0]):
cython_mit_mot_out_slices[_d0, _d1] = self.mit_mot_out_slices[_d0][
_d1
]
cython_vector_seqs = np.asarray(self.vector_seqs, dtype="int32")
cython_vector_outs = np.asarray(self.vector_outs, dtype="int32")
cython_mitmots_preallocated = np.asarray(
......@@ -1403,13 +1390,16 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_input_storage = [s.storage for s in self.fn.input_storage]
inner_output_storage = [s.storage for s in self.fn.output_storage]
inner_input_needs_update = [
inner_input_needs_update = tuple(
inp.update is not None for inp in self.fn.maker.expanded_inputs
]
output_dtypes = [getattr(out, "dtype", None) for out in node.outputs]
)
from . import scan_perform_ext
outer_output_dtypes = tuple(
getattr(out, "dtype", None) for out in node.outputs
)
outer_output_ndims = tuple(
getattr(out, "ndim", None) for out in node.outputs
)
def p(node, inputs, outputs):
......@@ -1429,7 +1419,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
tap_array_len,
cython_vector_seqs,
cython_vector_outs,
cython_mit_mot_out_slices,
self.mit_mot_out_slices,
cython_mitmots_preallocated,
cython_inps_is_tensor,
cython_outs_is_tensor,
......@@ -1441,7 +1431,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_destroy_map,
inputs,
outputs,
output_dtypes,
outer_output_dtypes,
outer_output_ndims,
)
t_call = time.perf_counter() - t0_call
......@@ -1566,7 +1557,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for idx in range(self.n_outs, self.n_outs + self.n_nit_sot):
out_var = node.outputs[idx]
if isinstance(out_var, TensorVariable):
output_storage[idx][0] = out_var.type.value_zeros(0)
output_storage[idx][0] = np.empty(
(0,) * out_var.type.ndim, dtype=out_var.type.dtype
)
else:
output_storage[idx][0] = None
return
......@@ -1875,7 +1868,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
or output_storage[j][0].shape[1:] != shape[1:]
or output_storage[j][0].dtype != dtype
):
output_storage[j][0] = node.outputs[j].type.value_zeros(shape)
output_storage[j][0] = np.empty(
shape, dtype=node.outputs[j].type
)
elif output_storage[j][0].shape[0] != store_steps[j]:
output_storage[j][0] = output_storage[j][0][: store_steps[j]]
output_storage[j][0][pos[j]] = inner_output_storage[jout].storage[0]
......@@ -1932,7 +1927,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# This way, there will be no information overwritten
# before it is read (as it used to happen).
shape = (pdx,) + output_storage[idx][0].shape[1:]
tmp = node.outputs[idx].type.value_zeros(shape)
tmp = np.empty(shape, dtype=node.outputs[idx].type)
tmp[:] = output_storage[idx][0][:pdx]
output_storage[idx][0][: store_steps[idx] - pdx] = output_storage[
idx
......@@ -1941,7 +1936,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
del tmp
else:
shape = (store_steps[idx] - pdx,) + output_storage[idx][0].shape[1:]
tmp = node.outputs[idx].type.value_zeros(shape)
tmp = np.empty(shape, dtype=node.outputs[idx].type)
tmp[:] = output_storage[idx][0][pdx:]
output_storage[idx][0][store_steps[idx] - pdx :] = output_storage[
idx
......
......@@ -75,19 +75,20 @@ def perform(
tuple tap_array_len,
numpy.ndarray[numpy.int32_t,ndim=1] vector_seqs,
numpy.ndarray[numpy.int32_t,ndim=1] vector_outs,
numpy.ndarray[numpy.int32_t,ndim=2] mit_mot_out_slices,
tuple mit_mot_out_slices,
numpy.ndarray[numpy.int32_t,ndim=1] mitmots_preallocated,
numpy.ndarray[numpy.int32_t,ndim=1] inps_is_tensor,
numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor,
list inner_input_storage,
list inner_output_storage,
bint need_update_inputs,
list inner_input_needs_update,
tuple inner_input_needs_update,
fnct,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
list outer_inputs,
list outer_outputs,
list output_dtypes,
tuple outer_output_dtypes,
tuple outer_output_ndims,
):
"""
Parameters
......@@ -128,7 +129,7 @@ def perform(
For each output ( mit_mot, mit_sot, sit_sot, nit_sot in this order)
the entry is 1 if the corresponding argument is a 1 dimensional
tensor, 0 otherwise.
mit_mot_out_slices : int32 ndarray( can be replaced by list of lists)
mit_mot_out_slices
Same as tap_array, but for the output taps of mit_mot sequences
inps_is_tensor : int32 ndarray (Can be replaced by a list)
Array of boolean indicating, for every input, whether it is a tensor
......@@ -143,7 +144,7 @@ def perform(
need_update_inputs
A boolean indicating whether or not inner inputs need to be updated.
inner_input_needs_update
A list of booleans indicating which inner inputs need to be updated.
A tuple of booleans indicating which inner inputs need to be updated.
fnct: Function
The compiled Aesara inner-function object.
destroy_map
......@@ -155,8 +156,10 @@ def perform(
This is where we need to copy our outputs ( we don't return the
results, though we can change the code such that we return, and
figure things out on the outside - python)
output_dtypes
The dtypes for each output.
outer_output_dtypes
The dtypes for each outer output.
outer_output_ndims
The number of dimensions for each outer output.
"""
# 1. Unzip the number of steps and sequences. If number of steps is
......@@ -252,7 +255,7 @@ def perform(
# (The answer is that you shouldn't have a `node` object to
# access, because it's not going to produce a very efficient
# Cython function!)
outer_outputs[idx][0] = numpy.zeros(0, dtype=output_dtypes[idx])
outer_outputs[idx][0] = numpy.empty((0,) * outer_output_ndims[idx], dtype=outer_output_dtypes[idx])
else:
outer_outputs[idx][0] = None
return
......@@ -299,15 +302,13 @@ def perform(
offset = n_seqs
for idx in range(n_outs):
if vector_outs[idx] == 1:
for tdx in range(tap_array_len[idx]):
tap = tap_array[idx][tdx]
for tap in tap_array[idx]:
_idx = (pos[idx]+tap)%store_steps[idx]
inner_input_storage[offset][0] =\
outer_outputs[idx][0][_idx:<unsigned int>(_idx+1)].reshape(())
offset += 1
else:
for tdx in range(tap_array_len[idx]):
tap = tap_array[idx][tdx]
for tap in tap_array[idx]:
_idx = (pos[idx]+tap)%store_steps[idx]
inner_input_storage[offset][0] = outer_outputs[idx][0][_idx]
offset += 1
......@@ -474,7 +475,7 @@ def perform(
mitmot_out_idx += 1
mitmot_inp_offset += len(tap_array[j])
mitmot_inp_offset += tap_array_len[j]
# 5.4 Copy over the values for mit_sot/sit_sot outputs
begin = n_mit_mot
......@@ -520,7 +521,7 @@ def perform(
outer_outputs[j][0].shape[0] < store_steps[j] or
outer_outputs[j][0].shape[1:] != shape[1:] or
outer_outputs[j][0].dtype != dtype ):
outer_outputs[j][0] = numpy.zeros(shape, dtype=output_dtypes[j])
outer_outputs[j][0] = numpy.empty(shape, dtype=outer_output_dtypes[j])
elif outer_outputs[j][0].shape[0] != store_steps[j]:
outer_outputs[j][0] = outer_outputs[j][0][:store_steps[j]]
outer_outputs[j][0][pos[j]] = inner_output_storage[jout][0]
......@@ -582,13 +583,13 @@ def perform(
# This way, there will be no information overwritten
# before it is read (as it used to happen).
shape = (pdx,)+ outer_outputs[idx][0].shape[1:]
tmp = numpy.zeros(shape, dtype=output_dtypes[idx])
tmp = numpy.empty(shape, dtype=outer_output_dtypes[idx])
tmp[:] = outer_outputs[idx][0][:pdx]
outer_outputs[idx][0][:store_steps[idx]-pdx] = outer_outputs[idx][0][pdx:]
outer_outputs[idx][0][store_steps[idx]-pdx:] = tmp
else:
shape = (store_steps[idx]-pdx,) + outer_outputs[idx][0].shape[1:]
tmp = numpy.zeros(shape, dtype=output_dtypes[idx])
tmp = numpy.empty(shape, dtype=outer_output_dtypes[idx])
tmp[:] = outer_outputs[idx][0][pdx:]
outer_outputs[idx][0][store_steps[idx]-pdx:] = outer_outputs[idx][0][:pdx]
outer_outputs[idx][0][:store_steps[idx]-pdx] = tmp
......
Markdown formatting is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment