Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2ee85105
提交
2ee85105
authored
1月 21, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
1月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove references to Op and Apply objects in Scan's Cython code
上级
4ca744f0
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
88 行增加
和
86 行删除
+88
-86
scan_perform.c
aesara/scan/c_code/scan_perform.c
+0
-0
op.py
aesara/scan/op.py
+31
-17
scan_perform.pyx
aesara/scan/scan_perform.pyx
+57
-69
没有找到文件。
aesara/scan/c_code/scan_perform.c
浏览文件 @
2ee85105
This source diff could not be displayed because it is too large. You can
view the blob
instead.
aesara/scan/op.py
浏览文件 @
2ee85105
...
...
@@ -1362,19 +1362,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
try:
    if impl == "py":
        raise MissingGXX
    cython_mintaps = np.asarray(self.mintaps, dtype="int32")
    cython_tap_array_len = np.asarray(
        [len(x) for x in self.tap_array], dtype="int32"
    )
    if len(self.tap_array) == 0:
        d1 = 0
    else:
        d1 = np.max(cython_tap_array_len)
    d0 = len(self.tap_array)
    cython_tap_array = np.zeros((d0, d1), dtype="int32")
    for _d0 in range(d0):
        for _d1 in range(cython_tap_array_len[_d0]):
            cython_tap_array[_d0, _d1] = self.tap_array[_d0][_d1]
    tap_array_len = tuple(len(x) for x in self.tap_array)
    cython_mit_mot_out_nslices = np.asarray(
        [len(x) for x in self.mit_mot_out_slices], dtype="int32"
    )
...
...
@@ -1411,10 +1403,19 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_input_storage = [s.storage for s in self.fn.input_storage]
inner_output_storage = [s.storage for s in self.fn.output_storage]
inner_input_needs_update = [
    inp.update is not None for inp in self.fn.maker.expanded_inputs
]
output_dtypes = [getattr(out, "dtype", None) for out in node.outputs]
from . import scan_perform_ext

def p(node, inputs, outputs):
    return scan_perform_ext.perform(

t0_call = time.perf_counter()
t_fn = scan_perform_ext.perform(
    self.n_shared_outs,
    self.n_mit_mot_outs,
    self.n_seqs,
...
...
@@ -1424,8 +1425,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.n_nit_sot,
self.as_while,
cython_mintaps,
cython_tap_array,
cython_tap_array_len,
self.tap_array,
tap_array_len,
cython_vector_seqs,
cython_vector_outs,
cython_mit_mot_out_slices,
...
...
@@ -1434,14 +1435,27 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_outs_is_tensor,
inner_input_storage,
inner_output_storage,
getattr(self.fn.fn, "need_update_inputs", True),
inner_input_needs_update,
self.fn,
cython_destroy_map,
inputs,
outputs,
self,
node,
output_dtypes,
)
t_call = time.perf_counter() - t0_call
if hasattr(self.fn.maker, "profile"):
    profile = self.fn.maker.profile
    if type(profile) is not bool and profile:
        profile.vm_call_time += t_fn
        profile.callcount += 1
        profile.nbsteps += outputs[0]
        profile.call_time += t_call
        if hasattr(self.fn.fn, "update_profile"):
            self.fn.fn.update_profile(profile)
except (ImportError, MissingGXX):
    p = self.perform
...
...
aesara/scan/scan_perform.pyx
浏览文件 @
2ee85105
...
...
@@ -62,31 +62,33 @@ def get_version():
@cython.boundscheck(False)
def perform(
unsigned int n_shared_outs,
unsigned int n_mit_mot_outs,
unsigned int n_seqs,
unsigned int n_mit_mot,
unsigned int n_mit_sot,
unsigned int n_sit_sot,
unsigned int n_nit_sot,
bint as_while,
numpy.ndarray[numpy.int32_t,ndim=1] mintaps,
numpy.ndarray[numpy.int32_t,ndim=2] tap_array,
numpy.ndarray[numpy.int32_t,ndim=1] tap_array_len,
numpy.ndarray[numpy.int32_t,ndim=1] vector_seqs,
numpy.ndarray[numpy.int32_t,ndim=1] vector_outs,
numpy.ndarray[numpy.int32_t,ndim=2] mit_mot_out_slices,
numpy.ndarray[numpy.int32_t,ndim=1] mitmots_preallocated,
numpy.ndarray[numpy.int32_t,ndim=1] inps_is_tensor,
numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor,
list inner_input_storage,
list inner_output_storage,
fnct,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
list outer_inputs,
list outer_outputs,
self,
node):
unsigned int n_shared_outs,
unsigned int n_mit_mot_outs,
unsigned int n_seqs,
unsigned int n_mit_mot,
unsigned int n_mit_sot,
unsigned int n_sit_sot,
unsigned int n_nit_sot,
bint as_while,
numpy.ndarray[numpy.int32_t,ndim=1] mintaps,
tuple tap_array,
tuple tap_array_len,
numpy.ndarray[numpy.int32_t,ndim=1] vector_seqs,
numpy.ndarray[numpy.int32_t,ndim=1] vector_outs,
numpy.ndarray[numpy.int32_t,ndim=2] mit_mot_out_slices,
numpy.ndarray[numpy.int32_t,ndim=1] mitmots_preallocated,
numpy.ndarray[numpy.int32_t,ndim=1] inps_is_tensor,
numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor,
list inner_input_storage,
list inner_output_storage,
bint need_update_inputs,
list inner_input_needs_update,
fnct,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
list outer_inputs,
list outer_outputs,
list output_dtypes,
):
"""
Parameters
----------
...
...
@@ -110,13 +112,13 @@ def perform(
away input tap from current position. For example, if the taps where [-2,
-5, -9], the mintap would be -9. For sit_sot this is always -1 since
is the only allowed tap.
tap_array : int32 ndarray( can be replaced by a list of list in python if better)
tap_array
For each of the mit_mot, mit_sot, sit_sot (the first dimension) says
which are the corresponding input taps. While this is a matrix, not all
values in a row are needed and tap_array_len is there to say up to
which entry we are dealing with valid taps ( afterwards there are
just 0s to ensure the fix format)
tap_array_len : int32 ndarray( can be replaced by a list if better)
tap_array_len
For each of the mit_mot, mit_sot, sit_sot says how many input taps
each has. For sit_sot this will always be 1.
vector_seqs: int32 ndarray (can be replaced by a list of bools if better)
...
...
@@ -138,6 +140,10 @@ def perform(
The storage locations for the inner-function's inputs.
inner_output_storage
The storage locations for the inner-function's outputs.
need_update_inputs
A boolean indicating whether or not inner inputs need to be updated.
inner_input_needs_update
A list of booleans indicating which inner inputs need to be updated.
fnct: Function
The compiled Aesara inner-function object.
destroy_map
...
...
@@ -149,15 +155,13 @@ def perform(
This is where we need to copy our outputs ( we don't return the
results, though we can change the code such that we return, and
figure things out on the outside - python)
self: python object
The scan op itself. I only use it to attach to it some timing
information .. but I don;t need to.
output_dtypes
The dtypes for each output.
"""
# 1. Unzip the number of steps and sequences. If number of steps is
# negative flip sequences around, and make n_steps positive
t0_call = time.time()
t_fn = 0
cdef unsigned int t_fn = 0
cdef unsigned int n_steps = outer_inputs[0].item()
cdef unsigned int n_outs = n_mit_mot + n_mit_sot + n_sit_sot
cdef unsigned int seqs_arg_offset = n_seqs + 1
...
...
@@ -191,7 +195,6 @@ def perform(
n_sit_sot + n_nit_sot +
n_shared_outs)
if n_steps < 0:
# History, in the past, this was used for backward
# scan. Now we reverse the inputs outside of scan.
...
...
@@ -249,7 +252,7 @@ def perform(
# (The answer is that you shouldn't have a `node` object to
# access, because it's not going to produce a very efficient
# Cython function!)
outer_outputs[idx][0] = node.outputs[idx].type.value_zeros(0)
outer_outputs[idx][0] = numpy.zeros(0, dtype=output_dtypes[idx])
else:
outer_outputs[idx][0] = None
return
...
...
@@ -297,14 +300,14 @@ def perform(
for idx in range(n_outs):
if vector_outs[idx] == 1:
for tdx in range(tap_array_len[idx]):
tap = tap_array[idx, tdx]
tap = tap_array[idx][tdx]
_idx = (pos[idx]+tap)%store_steps[idx]
inner_input_storage[offset][0] =\
outer_outputs[idx][0][_idx:<unsigned int>(_idx+1)].reshape(())
offset += 1
else:
for tdx in range(tap_array_len[idx]):
tap = tap_array[idx, tdx]
tap = tap_array[idx][tdx]
_idx = (pos[idx]+tap)%store_steps[idx]
inner_input_storage[offset][0] = outer_outputs[idx][0][_idx]
offset += 1
...
...
@@ -416,20 +419,19 @@ def perform(
dt_fn = time.time() - t0_fn
t_fn += dt_fn
if self.as_while:
if as_while:
pdx = offset + n_shared_outs
cond = inner_output_storage[pdx][0] == 0
# 5.2. By calling fn() directly instead of calling the aesara
# function, it is possible that the updates have not been
# performed. Perform the updates if needed.
offset_out = len(inner_output_storage) - 1
if getattr(fn, 'need_update_inputs', True):
# Update the inputs that have an update function
for inp, storage in zip(self.fn.maker.expanded_inputs[::-1],
self.fn.input_storage[::-1]):
if inp.update is not None:
storage.data = inner_output_storage[offset_out][0].data
if need_update_inputs:
offset_out = len(inner_output_storage) - 1
for needs_update, storage in zip(inner_input_needs_update[::-1],
inner_input_storage[::-1]):
if needs_update:
storage[0] = inner_output_storage[offset_out][0]
offset_out -= 1
offset_out = 0
...
...
@@ -437,12 +439,11 @@ def perform(
# 5.3 Copy over the values for mit_mot outputs
mitmot_inp_offset = 0
mitmot_out_idx = 0
for j in xrange(self.n_mit_mot):
for k in self.mit_mot_out_slices[j]:
for j in xrange(n_mit_mot):
for k in mit_mot_out_slices[j]:
if mitmots_preallocated[<unsigned int>mitmot_out_idx]:
# This output tap has been preallocated.
inp_idx = (mitmot_inp_offset +
self.tap_array[j].index(k))
inp_idx = (mitmot_inp_offset + tap_array[j].index(k))
# Verify whether the input points to the same data as
# it did before the execution of the inner function.
...
...
@@ -473,7 +474,7 @@ def perform(
mitmot_out_idx += 1
mitmot_inp_offset += len(self.tap_array[j])
mitmot_inp_offset += len(tap_array[j])
# 5.4 Copy over the values for mit_sot/sit_sot outputs
begin = n_mit_mot
...
...
@@ -519,7 +520,7 @@ def perform(
outer_outputs[j][0].shape[0] < store_steps[j] or
outer_outputs[j][0].shape[1:] != shape[1:] or
outer_outputs[j][0].dtype != dtype ):
outer_outputs[j][0] = node.outputs[j].type.value_zeros(shape)
outer_outputs[j][0] = numpy.zeros(shape, dtype=output_dtypes[j])
elif outer_outputs[j][0].shape[0] != store_steps[j]:
outer_outputs[j][0] = outer_outputs[j][0][:store_steps[j]]
outer_outputs[j][0][pos[j]] = inner_output_storage[jout][0]
...
...
@@ -581,23 +582,23 @@ def perform(
# This way, there will be no information overwritten
# before it is read (as it used to happen).
shape = (pdx,)+ outer_outputs[idx][0].shape[1:]
tmp = node.outputs[idx].type.value_zeros(shape)
tmp = numpy.zeros(shape, dtype=output_dtypes[idx])
tmp[:] = outer_outputs[idx][0][:pdx]
outer_outputs[idx][0][:store_steps[idx]-pdx] = outer_outputs[idx][0][pdx:]
outer_outputs[idx][0][store_steps[idx]-pdx:] = tmp
else:
shape = (store_steps[idx]-pdx,) + outer_outputs[idx][0].shape[1:]
tmp = node.outputs[idx].type.value_zeros(shape)
tmp = numpy.zeros(shape, dtype=output_dtypes[idx])
tmp[:] = outer_outputs[idx][0][pdx:]
outer_outputs[idx][0][store_steps[idx]-pdx:] = outer_outputs[idx][0][:pdx]
outer_outputs[idx][0][:store_steps[idx]-pdx] = tmp
# This would normally happen only when doing truncated
# backpropagation through time. In such a scenario Scan is
# expected to return 0 for all entries for which the gradient is
# not actually computed
elif store_steps[idx] > i - self.mintaps[idx]:
outer_outputs[idx][0][i-self.mintaps[idx]:] = 0
elif store_steps[idx] > i - mintaps[idx]:
outer_outputs[idx][0][i - mintaps[idx]:] = 0
# This is a fix for a bug introduced by while. If you say
# you want to loop up to a condition, you expect the output
...
...
@@ -623,17 +624,4 @@ def perform(
for s in inner_output_storage:
s[0] = None
t_call = time.time() - t0_call
if hasattr(fnct.maker, 'profile'):
profile = fnct.maker.profile
if type(profile) is not bool and profile:
profile.vm_call_time += t_fn
profile.callcount += 1
profile.nbsteps += n_steps
profile.call_time += t_call
if hasattr(fn, 'update_profile'):
fn.update_profile(profile)
self.t_call = t_call
self.t_fn = t_fn
return t_fn
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论