Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2ee85105
提交
2ee85105
authored
1月 21, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
1月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove references to Op and Apply objects in Scan's Cython code
上级
4ca744f0
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
66 行增加
和
64 行删除
+66
-64
scan_perform.c
aesara/scan/c_code/scan_perform.c
+0
-0
op.py
aesara/scan/op.py
+31
-17
scan_perform.pyx
aesara/scan/scan_perform.pyx
+35
-47
没有找到文件。
aesara/scan/c_code/scan_perform.c
浏览文件 @
2ee85105
This source diff could not be displayed because it is too large. You can
view the blob
instead.
aesara/scan/op.py
浏览文件 @
2ee85105
...
@@ -1362,19 +1362,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1362,19 +1362,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
try
:
try
:
if
impl
==
"py"
:
if
impl
==
"py"
:
raise
MissingGXX
raise
MissingGXX
cython_mintaps
=
np
.
asarray
(
self
.
mintaps
,
dtype
=
"int32"
)
cython_mintaps
=
np
.
asarray
(
self
.
mintaps
,
dtype
=
"int32"
)
cython_tap_array_len
=
np
.
asarray
(
[
len
(
x
)
for
x
in
self
.
tap_array
],
dtype
=
"int32"
tap_array_len
=
tuple
(
len
(
x
)
for
x
in
self
.
tap_array
)
)
if
len
(
self
.
tap_array
)
==
0
:
d1
=
0
else
:
d1
=
np
.
max
(
cython_tap_array_len
)
d0
=
len
(
self
.
tap_array
)
cython_tap_array
=
np
.
zeros
((
d0
,
d1
),
dtype
=
"int32"
)
for
_d0
in
range
(
d0
):
for
_d1
in
range
(
cython_tap_array_len
[
_d0
]):
cython_tap_array
[
_d0
,
_d1
]
=
self
.
tap_array
[
_d0
][
_d1
]
cython_mit_mot_out_nslices
=
np
.
asarray
(
cython_mit_mot_out_nslices
=
np
.
asarray
(
[
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
],
dtype
=
"int32"
[
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
],
dtype
=
"int32"
)
)
...
@@ -1411,10 +1403,19 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1411,10 +1403,19 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_input_storage
=
[
s
.
storage
for
s
in
self
.
fn
.
input_storage
]
inner_input_storage
=
[
s
.
storage
for
s
in
self
.
fn
.
input_storage
]
inner_output_storage
=
[
s
.
storage
for
s
in
self
.
fn
.
output_storage
]
inner_output_storage
=
[
s
.
storage
for
s
in
self
.
fn
.
output_storage
]
inner_input_needs_update
=
[
inp
.
update
is
not
None
for
inp
in
self
.
fn
.
maker
.
expanded_inputs
]
output_dtypes
=
[
getattr
(
out
,
"dtype"
,
None
)
for
out
in
node
.
outputs
]
from
.
import
scan_perform_ext
from
.
import
scan_perform_ext
def
p
(
node
,
inputs
,
outputs
):
def
p
(
node
,
inputs
,
outputs
):
return
scan_perform_ext
.
perform
(
t0_call
=
time
.
perf_counter
()
t_fn
=
scan_perform_ext
.
perform
(
self
.
n_shared_outs
,
self
.
n_shared_outs
,
self
.
n_mit_mot_outs
,
self
.
n_mit_mot_outs
,
self
.
n_seqs
,
self
.
n_seqs
,
...
@@ -1424,8 +1425,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1424,8 +1425,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self
.
n_nit_sot
,
self
.
n_nit_sot
,
self
.
as_while
,
self
.
as_while
,
cython_mintaps
,
cython_mintaps
,
cython_
tap_array
,
self
.
tap_array
,
cython_
tap_array_len
,
tap_array_len
,
cython_vector_seqs
,
cython_vector_seqs
,
cython_vector_outs
,
cython_vector_outs
,
cython_mit_mot_out_slices
,
cython_mit_mot_out_slices
,
...
@@ -1434,14 +1435,27 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1434,14 +1435,27 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_outs_is_tensor
,
cython_outs_is_tensor
,
inner_input_storage
,
inner_input_storage
,
inner_output_storage
,
inner_output_storage
,
getattr
(
self
.
fn
.
fn
,
"need_update_inputs"
,
True
),
inner_input_needs_update
,
self
.
fn
,
self
.
fn
,
cython_destroy_map
,
cython_destroy_map
,
inputs
,
inputs
,
outputs
,
outputs
,
self
,
output_dtypes
,
node
,
)
)
t_call
=
time
.
perf_counter
()
-
t0_call
if
hasattr
(
self
.
fn
.
maker
,
"profile"
):
profile
=
self
.
fn
.
maker
.
profile
if
type
(
profile
)
is
not
bool
and
profile
:
profile
.
vm_call_time
+=
t_fn
profile
.
callcount
+=
1
profile
.
nbsteps
+=
outputs
[
0
]
profile
.
call_time
+=
t_call
if
hasattr
(
self
.
fn
.
fn
,
"update_profile"
):
self
.
fn
.
fn
.
update_profile
(
profile
)
except
(
ImportError
,
MissingGXX
):
except
(
ImportError
,
MissingGXX
):
p
=
self
.
perform
p
=
self
.
perform
...
...
aesara/scan/scan_perform.pyx
浏览文件 @
2ee85105
...
@@ -71,8 +71,8 @@ def perform(
...
@@ -71,8 +71,8 @@ def perform(
unsigned int n_nit_sot,
unsigned int n_nit_sot,
bint as_while,
bint as_while,
numpy.ndarray[numpy.int32_t,ndim=1] mintaps,
numpy.ndarray[numpy.int32_t,ndim=1] mintaps,
numpy.ndarray[numpy.int32_t,ndim=2]
tap_array,
tuple
tap_array,
numpy.ndarray[numpy.int32_t,ndim=1]
tap_array_len,
tuple
tap_array_len,
numpy.ndarray[numpy.int32_t,ndim=1] vector_seqs,
numpy.ndarray[numpy.int32_t,ndim=1] vector_seqs,
numpy.ndarray[numpy.int32_t,ndim=1] vector_outs,
numpy.ndarray[numpy.int32_t,ndim=1] vector_outs,
numpy.ndarray[numpy.int32_t,ndim=2] mit_mot_out_slices,
numpy.ndarray[numpy.int32_t,ndim=2] mit_mot_out_slices,
...
@@ -81,12 +81,14 @@ def perform(
...
@@ -81,12 +81,14 @@ def perform(
numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor,
numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor,
list inner_input_storage,
list inner_input_storage,
list inner_output_storage,
list inner_output_storage,
bint need_update_inputs,
list inner_input_needs_update,
fnct,
fnct,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
list outer_inputs,
list outer_inputs,
list outer_outputs,
list outer_outputs,
self
,
list output_dtypes
,
node
):
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -110,13 +112,13 @@ def perform(
...
@@ -110,13 +112,13 @@ def perform(
away input tap from current position. For example, if the taps where [-2,
away input tap from current position. For example, if the taps where [-2,
-5, -9], the mintap would be -9. For sit_sot this is always -1 since
-5, -9], the mintap would be -9. For sit_sot this is always -1 since
is the only allowed tap.
is the only allowed tap.
tap_array
: int32 ndarray( can be replaced by a list of list in python if better)
tap_array
For each of the mit_mot, mit_sot, sit_sot (the first dimension) says
For each of the mit_mot, mit_sot, sit_sot (the first dimension) says
which are the corresponding input taps. While this is a matrix, not all
which are the corresponding input taps. While this is a matrix, not all
values in a row are needed and tap_array_len is there to say up to
values in a row are needed and tap_array_len is there to say up to
which entry we are dealing with valid taps ( afterwards there are
which entry we are dealing with valid taps ( afterwards there are
just 0s to ensure the fix format)
just 0s to ensure the fix format)
tap_array_len
: int32 ndarray( can be replaced by a list if better)
tap_array_len
For each of the mit_mot, mit_sot, sit_sot says how many input taps
For each of the mit_mot, mit_sot, sit_sot says how many input taps
each has. For sit_sot this will always be 1.
each has. For sit_sot this will always be 1.
vector_seqs: int32 ndarray (can be replaced by a list of bools if better)
vector_seqs: int32 ndarray (can be replaced by a list of bools if better)
...
@@ -138,6 +140,10 @@ def perform(
...
@@ -138,6 +140,10 @@ def perform(
The storage locations for the inner-function's inputs.
The storage locations for the inner-function's inputs.
inner_output_storage
inner_output_storage
The storage locations for the inner-function's outputs.
The storage locations for the inner-function's outputs.
need_update_inputs
A boolean indicating whether or not inner inputs need to be updated.
inner_input_needs_update
A list of booleans indicating which inner inputs need to be updated.
fnct: Function
fnct: Function
The compiled Aesara inner-function object.
The compiled Aesara inner-function object.
destroy_map
destroy_map
...
@@ -149,15 +155,13 @@ def perform(
...
@@ -149,15 +155,13 @@ def perform(
This is where we need to copy our outputs ( we don't return the
This is where we need to copy our outputs ( we don't return the
results, though we can change the code such that we return, and
results, though we can change the code such that we return, and
figure things out on the outside - python)
figure things out on the outside - python)
self: python object
output_dtypes
The scan op itself. I only use it to attach to it some timing
The dtypes for each output.
information .. but I don;t need to.
"""
"""
# 1. Unzip the number of steps and sequences. If number of steps is
# 1. Unzip the number of steps and sequences. If number of steps is
# negative flip sequences around, and make n_steps positive
# negative flip sequences around, and make n_steps positive
t0_call = time.time()
cdef unsigned int t_fn = 0
t_fn = 0
cdef unsigned int n_steps = outer_inputs[0].item()
cdef unsigned int n_steps = outer_inputs[0].item()
cdef unsigned int n_outs = n_mit_mot + n_mit_sot + n_sit_sot
cdef unsigned int n_outs = n_mit_mot + n_mit_sot + n_sit_sot
cdef unsigned int seqs_arg_offset = n_seqs + 1
cdef unsigned int seqs_arg_offset = n_seqs + 1
...
@@ -191,7 +195,6 @@ def perform(
...
@@ -191,7 +195,6 @@ def perform(
n_sit_sot + n_nit_sot +
n_sit_sot + n_nit_sot +
n_shared_outs)
n_shared_outs)
if n_steps < 0:
if n_steps < 0:
# History, in the past, this was used for backward
# History, in the past, this was used for backward
# scan. Now we reverse the inputs outside of scan.
# scan. Now we reverse the inputs outside of scan.
...
@@ -249,7 +252,7 @@ def perform(
...
@@ -249,7 +252,7 @@ def perform(
# (The answer is that you shouldn't have a `node` object to
# (The answer is that you shouldn't have a `node` object to
# access, because it's not going to produce a very efficient
# access, because it's not going to produce a very efficient
# Cython function!)
# Cython function!)
outer_outputs[idx][0] = n
ode.outputs[idx].type.value_zeros(0
)
outer_outputs[idx][0] = n
umpy.zeros(0, dtype=output_dtypes[idx]
)
else:
else:
outer_outputs[idx][0] = None
outer_outputs[idx][0] = None
return
return
...
@@ -297,14 +300,14 @@ def perform(
...
@@ -297,14 +300,14 @@ def perform(
for idx in range(n_outs):
for idx in range(n_outs):
if vector_outs[idx] == 1:
if vector_outs[idx] == 1:
for tdx in range(tap_array_len[idx]):
for tdx in range(tap_array_len[idx]):
tap = tap_array[idx
,
tdx]
tap = tap_array[idx
][
tdx]
_idx = (pos[idx]+tap)%store_steps[idx]
_idx = (pos[idx]+tap)%store_steps[idx]
inner_input_storage[offset][0] =\
inner_input_storage[offset][0] =\
outer_outputs[idx][0][_idx:<unsigned int>(_idx+1)].reshape(())
outer_outputs[idx][0][_idx:<unsigned int>(_idx+1)].reshape(())
offset += 1
offset += 1
else:
else:
for tdx in range(tap_array_len[idx]):
for tdx in range(tap_array_len[idx]):
tap = tap_array[idx
,
tdx]
tap = tap_array[idx
][
tdx]
_idx = (pos[idx]+tap)%store_steps[idx]
_idx = (pos[idx]+tap)%store_steps[idx]
inner_input_storage[offset][0] = outer_outputs[idx][0][_idx]
inner_input_storage[offset][0] = outer_outputs[idx][0][_idx]
offset += 1
offset += 1
...
@@ -416,20 +419,19 @@ def perform(
...
@@ -416,20 +419,19 @@ def perform(
dt_fn = time.time() - t0_fn
dt_fn = time.time() - t0_fn
t_fn += dt_fn
t_fn += dt_fn
if
self.
as_while:
if as_while:
pdx = offset + n_shared_outs
pdx = offset + n_shared_outs
cond = inner_output_storage[pdx][0] == 0
cond = inner_output_storage[pdx][0] == 0
# 5.2. By calling fn() directly instead of calling the aesara
# 5.2. By calling fn() directly instead of calling the aesara
# function, it is possible that the updates have not been
# function, it is possible that the updates have not been
# performed. Perform the updates if needed.
# performed. Perform the updates if needed.
if need_update_inputs:
offset_out = len(inner_output_storage) - 1
offset_out = len(inner_output_storage) - 1
if getattr(fn, 'need_update_inputs', True):
for needs_update, storage in zip(inner_input_needs_update[::-1],
# Update the inputs that have an update function
inner_input_storage[::-1]):
for inp, storage in zip(self.fn.maker.expanded_inputs[::-1],
if needs_update:
self.fn.input_storage[::-1]):
storage[0] = inner_output_storage[offset_out][0]
if inp.update is not None:
storage.data = inner_output_storage[offset_out][0].data
offset_out -= 1
offset_out -= 1
offset_out = 0
offset_out = 0
...
@@ -437,12 +439,11 @@ def perform(
...
@@ -437,12 +439,11 @@ def perform(
# 5.3 Copy over the values for mit_mot outputs
# 5.3 Copy over the values for mit_mot outputs
mitmot_inp_offset = 0
mitmot_inp_offset = 0
mitmot_out_idx = 0
mitmot_out_idx = 0
for j in xrange(
self.
n_mit_mot):
for j in xrange(n_mit_mot):
for k in
self.
mit_mot_out_slices[j]:
for k in mit_mot_out_slices[j]:
if mitmots_preallocated[<unsigned int>mitmot_out_idx]:
if mitmots_preallocated[<unsigned int>mitmot_out_idx]:
# This output tap has been preallocated.
# This output tap has been preallocated.
inp_idx = (mitmot_inp_offset +
inp_idx = (mitmot_inp_offset + tap_array[j].index(k))
self.tap_array[j].index(k))
# Verify whether the input points to the same data as
# Verify whether the input points to the same data as
# it did before the execution of the inner function.
# it did before the execution of the inner function.
...
@@ -473,7 +474,7 @@ def perform(
...
@@ -473,7 +474,7 @@ def perform(
mitmot_out_idx += 1
mitmot_out_idx += 1
mitmot_inp_offset += len(
self.
tap_array[j])
mitmot_inp_offset += len(tap_array[j])
# 5.4 Copy over the values for mit_sot/sit_sot outputs
# 5.4 Copy over the values for mit_sot/sit_sot outputs
begin = n_mit_mot
begin = n_mit_mot
...
@@ -519,7 +520,7 @@ def perform(
...
@@ -519,7 +520,7 @@ def perform(
outer_outputs[j][0].shape[0] < store_steps[j] or
outer_outputs[j][0].shape[0] < store_steps[j] or
outer_outputs[j][0].shape[1:] != shape[1:] or
outer_outputs[j][0].shape[1:] != shape[1:] or
outer_outputs[j][0].dtype != dtype ):
outer_outputs[j][0].dtype != dtype ):
outer_outputs[j][0] = n
ode.outputs[j].type.value_zeros(shape
)
outer_outputs[j][0] = n
umpy.zeros(shape, dtype=output_dtypes[j]
)
elif outer_outputs[j][0].shape[0] != store_steps[j]:
elif outer_outputs[j][0].shape[0] != store_steps[j]:
outer_outputs[j][0] = outer_outputs[j][0][:store_steps[j]]
outer_outputs[j][0] = outer_outputs[j][0][:store_steps[j]]
outer_outputs[j][0][pos[j]] = inner_output_storage[jout][0]
outer_outputs[j][0][pos[j]] = inner_output_storage[jout][0]
...
@@ -581,23 +582,23 @@ def perform(
...
@@ -581,23 +582,23 @@ def perform(
# This way, there will be no information overwritten
# This way, there will be no information overwritten
# before it is read (as it used to happen).
# before it is read (as it used to happen).
shape = (pdx,)+ outer_outputs[idx][0].shape[1:]
shape = (pdx,)+ outer_outputs[idx][0].shape[1:]
tmp = numpy.zeros(shape, dtype=output_dtypes[idx])
tmp = node.outputs[idx].type.value_zeros(shape)
tmp[:] = outer_outputs[idx][0][:pdx]
tmp[:] = outer_outputs[idx][0][:pdx]
outer_outputs[idx][0][:store_steps[idx]-pdx] = outer_outputs[idx][0][pdx:]
outer_outputs[idx][0][:store_steps[idx]-pdx] = outer_outputs[idx][0][pdx:]
outer_outputs[idx][0][store_steps[idx]-pdx:] = tmp
outer_outputs[idx][0][store_steps[idx]-pdx:] = tmp
else:
else:
shape = (store_steps[idx]-pdx,) + outer_outputs[idx][0].shape[1:]
shape = (store_steps[idx]-pdx,) + outer_outputs[idx][0].shape[1:]
tmp = n
ode.outputs[idx].type.value_zeros(shape
)
tmp = n
umpy.zeros(shape, dtype=output_dtypes[idx]
)
tmp[:] = outer_outputs[idx][0][pdx:]
tmp[:] = outer_outputs[idx][0][pdx:]
outer_outputs[idx][0][store_steps[idx]-pdx:] = outer_outputs[idx][0][:pdx]
outer_outputs[idx][0][store_steps[idx]-pdx:] = outer_outputs[idx][0][:pdx]
outer_outputs[idx][0][:store_steps[idx]-pdx] = tmp
outer_outputs[idx][0][:store_steps[idx]-pdx] = tmp
# This would normally happen only when doing truncated
# This would normally happen only when doing truncated
# backpropagation through time. In such a scenario Scan is
# backpropagation through time. In such a scenario Scan is
# expected to return 0 for all entries for which the gradient is
# expected to return 0 for all entries for which the gradient is
# not actually computed
# not actually computed
elif store_steps[idx] > i -
self.
mintaps[idx]:
elif store_steps[idx] > i - mintaps[idx]:
outer_outputs[idx][0][i
-self.
mintaps[idx]:] = 0
outer_outputs[idx][0][i
-
mintaps[idx]:] = 0
# This is a fix for a bug introduced by while. If you say
# This is a fix for a bug introduced by while. If you say
# you want to loop up to a condition, you expect the output
# you want to loop up to a condition, you expect the output
...
@@ -623,17 +624,4 @@ def perform(
...
@@ -623,17 +624,4 @@ def perform(
for s in inner_output_storage:
for s in inner_output_storage:
s[0] = None
s[0] = None
t_call = time.time() - t0_call
return t_fn
if hasattr(fnct.maker, 'profile'):
profile = fnct.maker.profile
if type(profile) is not bool and profile:
profile.vm_call_time += t_fn
profile.callcount += 1
profile.nbsteps += n_steps
profile.call_time += t_call
if hasattr(fn, 'update_profile'):
fn.update_profile(profile)
self.t_call = t_call
self.t_fn = t_fn
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论