Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
75b8b833
提交
75b8b833
authored
5月 18, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
6月 14, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove redundant indexing in scan_perform.pyx
上级
4af12379
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
118 行增加
和
95 行删除
+118
-95
scan_perform.c
aesara/scan/c_code/scan_perform.c
+0
-0
scan_perform.pyx
aesara/scan/scan_perform.pyx
+117
-94
scan_perform_ext.py
aesara/scan/scan_perform_ext.py
+1
-1
没有找到文件。
aesara/scan/c_code/scan_perform.c
浏览文件 @
75b8b833
This source diff could not be displayed because it is too large. You can
view the blob
instead.
aesara/scan/scan_perform.pyx
浏览文件 @
75b8b833
...
@@ -58,9 +58,14 @@ import sys
...
@@ -58,9 +58,14 @@ import sys
from aesara.scan.utils import InnerFunctionError
from aesara.scan.utils import InnerFunctionError
numpy.import_array()
def get_version():
def get_version():
return 0.31
7
return 0.31
8
# TODO: We need to get rid of the negative indexing performed with `pos` and `l`.
# @cython.wraparound(False)
@cython.boundscheck(False)
@cython.boundscheck(False)
def perform(
def perform(
const unsigned int n_shared_outs,
const unsigned int n_shared_outs,
...
@@ -160,6 +165,8 @@ def perform(
...
@@ -160,6 +165,8 @@ def perform(
# 1. Unzip the number of steps and sequences. If number of steps is
# 1. Unzip the number of steps and sequences. If number of steps is
# negative flip sequences around, and make n_steps positive
# negative flip sequences around, and make n_steps positive
cdef float t_fn = 0
cdef float t_fn = 0
cdef float t0_fn
cdef float dt_fn
cdef unsigned int n_steps = outer_inputs[0].item()
cdef unsigned int n_steps = outer_inputs[0].item()
cdef unsigned int n_outs = n_mit_mot + n_mit_sot + n_sit_sot
cdef unsigned int n_outs = n_mit_mot + n_mit_sot + n_sit_sot
cdef unsigned int seqs_arg_offset = n_seqs + 1
cdef unsigned int seqs_arg_offset = n_seqs + 1
...
@@ -189,6 +196,16 @@ def perform(
...
@@ -189,6 +196,16 @@ def perform(
cdef unsigned int len_output_storage = (n_mit_mot_outs + n_mit_sot +
cdef unsigned int len_output_storage = (n_mit_mot_outs + n_mit_sot +
n_sit_sot + n_nit_sot +
n_sit_sot + n_nit_sot +
n_shared_outs)
n_shared_outs)
cdef unsigned int mitmot_inp_offset
cdef unsigned int mitmot_out_idx
cdef unsigned int inp_idx
cdef unsigned int inner_inp_idx
cdef unsigned int store_steps_j
cdef unsigned int store_steps_idx
cdef int mintaps_idx
cdef unsigned int sh0
cdef long pos_j
cdef long pos_idx
if n_steps < 0:
if n_steps < 0:
# Historically, this was used for backward
# Historically, this was used for backward
...
@@ -220,23 +237,29 @@ def perform(
...
@@ -220,23 +237,29 @@ def perform(
# 2.1 Create storage space for outputs
# 2.1 Create storage space for outputs
for idx in range(n_outs):
for idx in range(n_outs):
outer_outputs_idx = outer_outputs[idx]
if destroy_map[idx] != 0:
if destroy_map[idx] != 0:
# ^ Case 1. Outputs should be computed inplace of their
# ^ Case 1. Outputs should be computed inplace of their
# initial state
# initial state
outer_outputs[idx][0] = outer_inputs[ <unsigned int>(1+ n_seqs + idx)]
outer_outputs_idx[0] = outer_inputs[ <unsigned int>(1+ n_seqs + idx)]
elif ( outer_outputs[idx][0] is not None and
continue
outer_outputs[idx][0].shape[1:] == outer_inputs[<unsigned int>(1+ n_seqs + idx)].shape[1:]
and outer_outputs[idx][0].shape[0] >= store_steps[idx] ):
outer_outputs_idx_0 = outer_outputs_idx[0]
if ( outer_outputs_idx_0 is not None and
outer_outputs_idx_0.shape[1:] == outer_inputs[<unsigned int>(1+ n_seqs + idx)].shape[1:]
and outer_outputs_idx_0.shape[0] >= store_steps[idx] ):
# Put in the values of the initial state
# Put in the values of the initial state
outer_outputs
[idx][0] = outer_outputs[idx][0]
[:store_steps[idx]]
outer_outputs
_idx[0] = outer_outputs_idx_0
[:store_steps[idx]]
if idx > n_mit_mot:
if idx > n_mit_mot:
# TODO FIXME: Do not use wrapped indexing!
l = - mintaps[idx]
l = - mintaps[idx]
outer_outputs[idx][0][:l] = outer_inputs[<unsigned int>(seqs_arg_offset +
outer_outputs_idx_0[:l] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)][:l]
idx)][:l]
else:
else:
outer_outputs
[idx][0]
[:] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)]
outer_outputs
_idx_0
[:] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)]
else:
else:
outer_outputs
[idx]
[0] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)].copy()
outer_outputs
_idx
[0] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)].copy()
if n_steps == 0:
if n_steps == 0:
for idx in range(n_outs, n_outs + n_nit_sot):
for idx in range(n_outs, n_outs + n_nit_sot):
...
@@ -252,12 +275,13 @@ def perform(
...
@@ -252,12 +275,13 @@ def perform(
return 0.0, 0
return 0.0, 0
for idx in range(lenpos):
for idx in range(lenpos):
# TODO FIXME: Do not use wrapped indexing!
pos[idx] = -mintaps[idx] % store_steps[idx]
pos[idx] = -mintaps[idx] % store_steps[idx]
offset = nit_sot_arg_offset + n_nit_sot
offset = nit_sot_arg_offset + n_nit_sot
other_args = outer_inputs[offset:]
other_args = outer_inputs[offset:]
nb_mitmot_in = 0
cdef unsigned int
nb_mitmot_in = 0
for idx in range(n_mit_mot):
for idx in range(n_mit_mot):
nb_mitmot_in += tap_array_len[idx]
nb_mitmot_in += tap_array_len[idx]
...
@@ -290,16 +314,20 @@ def perform(
...
@@ -290,16 +314,20 @@ def perform(
offset = n_seqs
offset = n_seqs
for idx in range(n_outs):
for idx in range(n_outs):
pos_idx = pos[idx]
store_steps_idx = store_steps[idx]
outer_outputs_idx = outer_outputs[idx]
if vector_outs[idx] == 1:
if vector_outs[idx] == 1:
for tap in tap_array[idx]:
for tap in tap_array[idx]:
_idx = (pos
[idx]+tap)%store_steps[idx]
_idx = (pos
_idx + tap) % store_steps_idx
inner_input_storage[offset][0] =\
inner_input_storage[offset][0] =\
outer_outputs
[idx][0][_idx:<unsigned int>(_idx+
1)].reshape(())
outer_outputs
_idx[0][_idx:<unsigned int>(_idx +
1)].reshape(())
offset += 1
offset += 1
else:
else:
for tap in tap_array[idx]:
for tap in tap_array[idx]:
_idx = (pos
[idx]+tap)%store_steps[idx]
_idx = (pos
_idx + tap) % store_steps_idx
inner_input_storage[offset][0] = outer_outputs
[idx]
[0][_idx]
inner_input_storage[offset][0] = outer_outputs
_idx
[0][_idx]
offset += 1
offset += 1
...
@@ -384,7 +412,7 @@ def perform(
...
@@ -384,7 +412,7 @@ def perform(
try:
try:
fn()
fn()
except Exception as exc:
except Exception as exc:
raise InnerFunctionError(exc, sys.exc_info()[
-1
])
raise InnerFunctionError(exc, sys.exc_info()[
2
])
dt_fn = time.time() - t0_fn
dt_fn = time.time() - t0_fn
t_fn += dt_fn
t_fn += dt_fn
...
@@ -398,34 +426,33 @@ def perform(
...
@@ -398,34 +426,33 @@ def perform(
mitmot_inp_offset = 0
mitmot_inp_offset = 0
mitmot_out_idx = 0
mitmot_out_idx = 0
for j in range(n_mit_mot):
for j in range(n_mit_mot):
tap_array_j = tap_array[j]
pos_j = pos[j]
outer_outputs_j_0 = outer_outputs[j][0]
for k in mit_mot_out_slices[j]:
for k in mit_mot_out_slices[j]:
if mitmots_preallocated[
<unsigned int>
mitmot_out_idx]:
if mitmots_preallocated[mitmot_out_idx]:
# This output tap has been preallocated.
# This output tap has been preallocated.
inp_idx =
(mitmot_inp_offset + tap_array[j].index(k)
)
inp_idx =
mitmot_inp_offset + tap_array_j.index(k
)
inner_inp_idx = n_seqs + inp_idx
inner_inp_idx = n_seqs + inp_idx
# Verify whether the input points to the same data as
# Verify whether the input points to the same data as
# it did before the execution of the inner function.
# it did before the execution of the inner function.
old_var = old_mitmot_input_storage[inp_idx]
old_var = old_mitmot_input_storage[inp_idx]
new_var = inner_input_storage[inner_inp_idx][0]
new_var = inner_input_storage[inner_inp_idx][0]
if old_var is new_var:
old_data = old_mitmot_input_data[inp_idx]
same_data = (new_var.data == old_data)
else:
same_data = False
# If the corresponding input storage has been replaced,
# If the corresponding input storage has been replaced,
# recover the value as usual. Otherwise, the input was
# recover the value as usual. Otherwise, the input was
# modified inplace and nothing needs to be done.
# modified inplace and nothing needs to be done.
if not same_data:
if old_var is not new_var or old_mitmot_input_data[inp_idx] != new_var.data:
outer_outputs[j][0][<unsigned int>(k + pos[j])] = \
outer_outputs_j_0[<unsigned int>(k + pos_j)] = \
inner_input_storage[<unsigned int>(inner_inp_idx)][0]
inner_input_storage[inner_inp_idx][0]
else:
else:
# This output tap has not been preallocated, recover
# This output tap has not been preallocated, recover
# its value as usual
# its value as usual
outer_outputs
[j][0][<unsigned int>(k + pos[j]
)] = \
outer_outputs
_j_0[<unsigned int>(k + pos_j
)] = \
inner_output_storage[
<unsigned int>
offset_out][0]
inner_output_storage[offset_out][0]
offset_out += 1
offset_out += 1
mitmot_out_idx += 1
mitmot_out_idx += 1
...
@@ -439,72 +466,63 @@ def perform(
...
@@ -439,72 +466,63 @@ def perform(
for j in range(begin, end):
for j in range(begin, end):
jout = j + offset_out
outer_outputs_j = outer_outputs[j]
# Copy the output value to `outer_outputs`, if necessary
# Copy the output value to `outer_outputs`, if necessary
if store_steps[j] == 1 or vector_outs[j] == 1:
if store_steps[j] == 1 or vector_outs[j] == 1:
outer_outputs
[j][0][pos[j]] = inner_output_storage[<unsigned int>(offset_out+j)
][0]
outer_outputs
_j[0][pos[j]] = inner_output_storage[jout
][0]
else:
else:
# Check whether the initialization of the output storage map
# Check whether the initialization of the output storage map
# for this output has been reused.
# for this output has been reused.
old_var = old_output_storage[offset_out + j]
old_var = old_output_storage[jout]
old_data = old_output_data[offset_out + j]
old_data = old_output_data[jout]
new_var = inner_output_storage[offset_out + j][0]
new_var = inner_output_storage[jout][0]
if old_var is new_var:
if old_data is None:
output_reused = False
else:
output_reused = (new_var.data == old_data)
else:
output_reused = False
if not output_reused:
outer_outputs[j][0][pos[j]] = \
inner_output_storage[<unsigned int>(offset_out+j)][0]
if old_var is not new_var or old_data is None:
outer_outputs_j[0][pos[j]] = new_var
# 5.5 Copy over the values for nit_sot outputs
# 5.5 Copy over the values for nit_sot outputs
begin = end
begin = end
end += n_nit_sot
end += n_nit_sot
for j in range(begin,end):
for j in range(begin,end):
jout = j + offset_out
if i == 0:
if i == 0:
jout = j+offset_out
store_steps_j = store_steps[j]
shape = (store_steps[j],) + inner_output_storage[jout][0].shape
inner_output_storage_jout_0 = inner_output_storage[jout][0]
dtype = inner_output_storage[jout][0].dtype
shape = (store_steps_j,) + inner_output_storage_jout_0.shape
if (outer_outputs[j][0] is None or
dtype = inner_output_storage_jout_0.dtype
outer_outputs[j][0].shape[0] < store_steps[j] or
outer_outputs_j = outer_outputs[j]
outer_outputs[j][0].shape[1:] != shape[1:] or
outer_outputs_j_0 = outer_outputs_j[0]
outer_outputs[j][0].dtype != dtype ):
outer_outputs[j][0] = numpy.empty(shape, dtype=outer_output_dtypes[j])
if (
elif outer_outputs[j][0].shape[0] != store_steps[j]:
outer_outputs_j_0 is None or
outer_outputs[j][0] = outer_outputs[j][0][:store_steps[j]]
outer_outputs_j_0.shape[0] < store_steps_j or
outer_outputs[j][0][pos[j]] = inner_output_storage[jout][0]
outer_outputs_j_0.shape[1:] != shape[1:] or
outer_outputs_j_0.dtype != dtype
):
new_outer_outputs_j_0 = numpy.empty(shape, dtype=outer_output_dtypes[j])
elif outer_outputs_j_0.shape[0] != store_steps_j:
new_outer_outputs_j_0 = outer_outputs_j_0[:store_steps_j]
else:
new_outer_outputs_j_0 = outer_outputs_j_0
new_outer_outputs_j_0[pos[j]] = inner_output_storage_jout_0
outer_outputs_j[0] = new_outer_outputs_j_0
elif store_steps[j] == 1 or vector_outs[j] == 1:
elif store_steps[j] == 1 or vector_outs[j] == 1:
outer_outputs[j][0][pos[j]] = inner_output_storage[j
+offset_
out][0]
outer_outputs[j][0][pos[j]] = inner_output_storage[jout][0]
else:
else:
# Check whether the initialization of the output storage map
# Check whether the initialization of the output storage map
# for this output has been reused.
# for this output has been reused.
old_var = old_output_storage[offset_out + j]
old_var = old_output_storage[jout]
old_data = old_output_data[offset_out + j]
old_data = old_output_data[jout]
new_var = inner_output_storage[offset_out + j][0]
new_var = inner_output_storage[jout][0]
if old_var is new_var:
if old_data is None:
if old_var is not new_var or old_data is None:
output_reused = False
outer_outputs[j][0][pos[j]] = new_var
else:
output_reused = (new_var.data == old_data)
else:
output_reused = False
if not output_reused:
try:
outer_outputs[j][0][pos[j]] = inner_output_storage[j+offset_out][0]
except ValueError as e:
if i == 0:
raise
raise ValueError(
"An output of the Scan has changed shape. "
"This may be caused by a push-out optimization."
" Try adding 'optimizer_excluding=scan_pushout'"
" to your Aesara flags.")
# 5.6 Copy over the values for outputs corresponding to shared
# 5.6 Copy over the values for outputs corresponding to shared
# variables
# variables
...
@@ -522,35 +540,40 @@ def perform(
...
@@ -522,35 +540,40 @@ def perform(
begin = n_mit_mot
begin = n_mit_mot
end = n_outs + n_nit_sot
end = n_outs + n_nit_sot
for idx in range(begin, end):
for idx in range(begin, end):
if ( store_steps[idx] < i-mintaps[idx] and
outer_outputs_idx = outer_outputs[idx]
pos[idx] < store_steps[idx] ):
outer_outputs_idx_0 = outer_outputs_idx[0]
store_steps_idx = store_steps[idx]
mintaps_idx = mintaps[idx]
pdx = pos[idx]
pdx = pos[idx]
if (store_steps_idx < i - mintaps_idx and pdx < store_steps_idx ):
if pdx >= store_steps
[idx]//
2 :
if pdx >= store_steps
_idx //
2 :
# It seems inefficient to copy the bigger part of the
# It seems inefficient to copy the bigger part of the
# array over, and back, but it is the only way that
# array over, and back, but it is the only way that
# there is no overlap in the areas of out[idx][0] that
# there is no overlap in the areas of out[idx][0] that
# are read and written.
# are read and written.
# This way, there will be no information overwritten
# This way, there will be no information overwritten
# before it is read (as it used to happen).
# before it is read (as it used to happen).
shape = (pdx,)+ outer_outputs
[idx][0]
.shape[1:]
shape = (pdx,)+ outer_outputs
_idx_0
.shape[1:]
tmp = numpy.empty(shape, dtype=outer_output_dtypes[idx])
tmp = numpy.empty(shape, dtype=outer_output_dtypes[idx])
tmp[:] = outer_outputs
[idx][0]
[:pdx]
tmp[:] = outer_outputs
_idx_0
[:pdx]
outer_outputs
[idx][0][:store_steps[idx]-pdx] = outer_outputs[idx][0]
[pdx:]
outer_outputs
_idx_0[:store_steps_idx - pdx] = outer_outputs_idx_0
[pdx:]
outer_outputs
[idx][0][store_steps[idx]-
pdx:] = tmp
outer_outputs
_idx_0[store_steps_idx -
pdx:] = tmp
else:
else:
shape = (store_steps
[idx]-pdx,) + outer_outputs[idx][0]
.shape[1:]
shape = (store_steps
_idx - pdx,) + outer_outputs_idx_0
.shape[1:]
tmp = numpy.empty(shape, dtype=outer_output_dtypes[idx])
tmp = numpy.empty(shape, dtype=outer_output_dtypes[idx])
tmp[:] = outer_outputs
[idx][0]
[pdx:]
tmp[:] = outer_outputs
_idx_0
[pdx:]
outer_outputs
[idx][0][store_steps[idx]-pdx:] = outer_outputs[idx][0]
[:pdx]
outer_outputs
_idx_0[store_steps_idx - pdx:] = outer_outputs_idx_0
[:pdx]
outer_outputs
[idx][0][:store_steps[idx]-
pdx] = tmp
outer_outputs
_idx_0[:store_steps_idx -
pdx] = tmp
# This would normally happen only when doing truncated
# This would normally happen only when doing truncated
# backpropagation through time. In such a scenario Scan is
# backpropagation through time. In such a scenario Scan is
# expected to return 0 for all entries for which the gradient is
# expected to return 0 for all entries for which the gradient is
# not actually computed
# not actually computed
elif store_steps
[idx] > i - mintaps[idx]
:
elif store_steps
_idx > i - mintaps_idx
:
outer_outputs
[idx][0][i - mintaps[idx]
:] = 0
outer_outputs
_idx_0[i - mintaps_idx
:] = 0
# This is a fix for a bug introduced by while. If you say
# This is a fix for a bug introduced by while. If you say
# you want to loop up to a condition, you expect the output
# you want to loop up to a condition, you expect the output
...
@@ -566,8 +589,8 @@ def perform(
...
@@ -566,8 +589,8 @@ def perform(
# to do boundschecks). The directive is used to make the
# to do boundschecks). The directive is used to make the
# code faster, so this workaround is better than removing
# code faster, so this workaround is better than removing
# the directive.
# the directive.
sh0 = outer_outputs
[idx][0]
.shape[0]
sh0 = outer_outputs
_idx_0
.shape[0]
outer_outputs
[idx][0] = outer_outputs[idx][0]
[:sh0-(n_steps - i)]
outer_outputs
_idx[0] = outer_outputs_idx_0
[:sh0-(n_steps - i)]
# We never reuse the input or output storage of the
# We never reuse the input or output storage of the
# inner function so we clear it.
# inner function so we clear it.
...
...
aesara/scan/scan_perform_ext.py
浏览文件 @
75b8b833
...
@@ -23,7 +23,7 @@ if not config.cxx:
...
@@ -23,7 +23,7 @@ if not config.cxx:
_logger
=
logging
.
getLogger
(
"aesara.scan.scan_perform"
)
_logger
=
logging
.
getLogger
(
"aesara.scan.scan_perform"
)
version
=
0.31
7
# must match constant returned in function get_version()
version
=
0.31
8
# must match constant returned in function get_version()
need_reload
=
False
need_reload
=
False
scan_perform
:
Optional
[
ModuleType
]
=
None
scan_perform
:
Optional
[
ModuleType
]
=
None
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论